https://www.acmicpc.net/problem/2447

 

2447번: 별 찍기 - 10

재귀적인 패턴으로 별을 찍어 보자. N이 3의 거듭제곱(3, 9, 27, ...)이라고 할 때, 크기 N의 패턴은 N×N 정사각형 모양이다. 크기 3의 패턴은 가운데에 공백이 있고, 가운데를 제외한 모든 칸에 별이

www.acmicpc.net

 

풀이1

(3^k)꼴의 N과 3x3 패턴의 출력 형식이 주어졌을 때, 그 패턴에 맞춰 NxN 크기로 별을 출력하는 문제였다.

 

NxN은 (N/3) x (N/3) 크기의 영역 9개로 이루어져 있으며, 그 나누어진 영역은 다시 9개의 더 작은 영역으로 이루어져 있다. 그래서 문제에서 주어진 것처럼, 재귀적인 분할 정복 방식을 사용하여 문제를 해결할 수 있다.

 

9x9 = (3x3) 9개

위 그림은 N = 9일 때를 나타낸 것이다. 만약 N = 27이라면 이 그림이 9개로 이루어져 있을 것이다.

 

즉, NxN을 그리기 위해선 (N/3) x (N/3) 9개를 그려야 하고, N/3을 다시 나누다 보면 결국 3x3 영역부터 그리기 시작해서 전체를 그리게 된다.

 

참고로 영역은 배열을 사용하여,

가장 왼쪽 상단의 시작점(row, col)과 영역의 크기(size)를 통해 영역을 나타낼 수 있다.

배열, 시작점, 크기로 영역을 나타낼 수 있다

 

그러면 이제 NxN 패턴의 별을 그리는 재귀 함수를 선언하고, 그 안에서 크기가 3이 될 때까지 (N/3) x (N/3) 영역 9개 그리도록 재귀적으로 호출하면 문제를 해결할 수 있다.

위의 과정에서 영역을 인자로 받아서 3x3패턴의 별을 배열에 저장해주는 함수공백을 저장하는 함수가 추가적으로 필요하다.

 

시작점: (row, col), 크기: size인 별을 그리는 함수

 

풀이2

풀이1은 영역을 중심으로 분할해서 가장 작은 단위부터 해결하는 방식이었다. 다른 방법은 없을까 궁금해서 다른 분들의 풀이를 살펴보다가, 같은 분할 정복이지만 약간 다른 관점의 방법이 있어서 소개하려고 한다.

 

이 방법도 어느 영역에 속하는지가 중요하긴 하지만, 영역 자체가 아닌 각 점이 어느 영역에 속하는지 판단하여 해당 위치에 '*' 이나 ' '를 저장한다.  

3x3의 빈칸

우선, 기본 단위인 3x3에서 빈칸은 각 정사각형의 1행 1열에 해당한다. 좌표는 (1,1), (1, 4), (1, 7), (1, 10), (1, 13)이며,

판별식은 (i % 3 == 1) && (j % 3 == 1)이다.

 

9x9에서도 1행 1열이 공백에 해당한다

9x9 패턴에서도 공백은 3x3 단위로 봤을 때 1행 1열 영역에 해당한다.

 

즉, N x N 패턴에서 1행 1열에 있는 (3/N) x (3/N) 영역 전부가 공백에 해당한다.

그리고 모든 크기로 확장한 판별식은 {(i / size) % 3} == 1 && {(j / size) % 3} == 1이 된다.

i / size 은 그 점이 3*size x 3*size 단위로 봤을 때, 몇 번째 행의 size x size 영역인지를 나타낸다.

 

예를 들어 N = 27, size = 9, i = 11, j = 0이라면,

27x27의 일부분

 

i / (size/3) = 3이므로 위에서 네 번째 3x3 영역에 속하며,

3 % 3 = 0이므로 9x9 영역에서 0행 0열에 있는 3x3에 속한다. 즉 1행 1열이 아니기 때문에 공백이 아닌걸 알 수 있다.

 

그리고 이것을 재귀적 코드로 나타내면 아래와 같다.

 

코드1

import java.io.BufferedReader
import java.io.InputStreamReader

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val n = readLine().toInt()
    val stars = Array(n) { CharArray(n) }

    setStars(stars, 0, 0, n) // NxN 패턴 별을 그린다

    val answer = StringBuilder()
    stars.forEach { charArray ->
        charArray.forEach {
            answer.append(it)
        }

        answer.appendLine()
    }

    print(answer)
}

// 왼쪽 위의 점인 stars[row][col]에서부터 size * size만큼 별 저장 
fun setStars(stars: Array<CharArray>, row: Int, col: Int, size: Int) {

    // 크기가 3일 때까지 나누어 들어가다가, 3이 되면 해당 영역에 3x3 패턴을 저장하고 종료한다.
    if (size == 3) {
        set3Pattern(stars, row, col, size)
        return
    }

    val dividedSize = size / 3

    // 0행
    setStars(stars, row, col, dividedSize)
    setStars(stars, row, col + dividedSize, dividedSize)
    setStars(stars, row, col + 2*dividedSize, dividedSize)
    
    // 1행
    setStars(stars, row + dividedSize, col, dividedSize)
    setSpace(stars, row + dividedSize, col + dividedSize, dividedSize) // 공백
    setStars(stars, row + dividedSize, col + 2*dividedSize, dividedSize)

    // 2행
    setStars(stars, row + 2*dividedSize, col, dividedSize)
    setStars(stars, row + 2*dividedSize, col + dividedSize, dividedSize)
    setStars(stars, row + 2*dividedSize, col + 2*dividedSize, dividedSize)
}

// 3 * 3 패턴 별 저장
fun set3Pattern(stars: Array<CharArray>, startRow: Int, startCol: Int, size: Int) {
    for (row in startRow until startRow + size){
        for (col in startCol until startCol + size){
            stars[row][col] = '*'
        }
    }

    stars[startRow + 1][startCol + 1] = ' '
}

// 공백 저장
fun setSpace(stars: Array<CharArray>, startRow: Int, startCol: Int, size: Int) {
    for (row in startRow until startRow + size){
        for (col in startCol until startCol + size){
            stars[row][col] = ' '
        }
    }
}

 

코드2

import java.io.BufferedReader
import java.io.InputStreamReader

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val n = readLine().toInt()
    val answer = StringBuilder()

    for (i in 0 until n) {
        for (j in 0 until n) {
            answer.append(getStar(i, j, n / 3)) // 위치에 해당하는 '*' or ' ' 반환
        }
        answer.appendLine()
    }

    print(answer)
}

fun getStar(row: Int, col: Int, size: Int): Char {

    // 해당 위치가 (3*size) * (3*size)의 정사각형 영역 안에서,
    // 가운데 영역인 (size * size) 정사각형 안에 포함될 경우
    if ((row / size) % 3 == 1 && (col / size) % 3 == 1){
        return ' '
    }

    return if (size == 1) '*'
    else getStar(row, col, size / 3)
}

https://www.acmicpc.net/problem/1780

 

1780번: 종이의 개수

N×N크기의 행렬로 표현되는 종이가 있다. 종이의 각 칸에는 -1, 0, 1 중 하나가 저장되어 있다. 우리는 이 행렬을 다음과 같은 규칙에 따라 적절한 크기로 자르려고 한다. 만약 종이가 모두 같은 수

www.acmicpc.net

 

풀이

NxN 행렬에서 숫자가 모두 같으면 해당 숫자를 카운트, 그렇지 않으면 크기를 9등분해서 각 행렬에 대해 같은 작업을 반복하는 문제였다. 분할 정복 or 재귀를 사용한다는 건 금방 알 수 있었지만 행렬을 나누는 게 어렵게 느껴졌다. 

 

행렬은 어떻게 표현하느냐에 따라 나누는 방법이 달라진다.

행렬은 어떻게 표현하는 게 좋을까? 나같은 경우 한 점과 길이로 구분하는 것이 가장 쉽다고 느꼈다.

행렬 표현

 

모두 정사각 행렬이기 때문에, 이렇게 시작 점과 길이를 알면 행렬을 간단하게 나타낼 수 있다.

 

그리고 점과 길이를 안다면 (점 + 길이)로 다른 행렬의 시작점을 계산하여 쉽게 분할할 수 있다.

행렬 분할

 

그러면 이제 큐or 스택과 같은 자료구조를 이용하는 반복문이나, 재귀 함수를 작성해서 답을 구할 수 있다.

 

행렬이 모두 같은 원소를 갖으면 해당 숫자를 카운트하고 분할을 종료한다. 만약 길이가 1일 경우, 원소가 하나라서 항상 같은 원소를 갖고 있므로 분할은 자동 종료된다.

 

리뷰

처음엔 왼쪽 끝의 좌표와 오른쪽 끝의 좌표를 사용해서 행렬을 구분하려고 했는데, 좌표를 구하는 연산 과정이 복잡해져서 멘붕이 왔었다..ㅋㅋㅋ

 

두 점으로 행렬을 구분하기 위해서는, 그림처럼 서로 다른 18개의 점을 구해야 한다. 그리고 다음 행렬의 좌표는 이전의 행이나 열에 +1을 해줘야 한다. 이처럼 그림 없이는 헷갈릴 수가 있는데, 머리로만 하려고 했던 게 가장 큰 문제였다.

 

적극적으로 문제를 풀어주자!!

 

나중에 멘탈을 다잡고 그림을 보며 위의 방식으로 해봤는데 결과는 시간초과 였다ㅋㅋ;;

 

 

코드1 (재귀 사용)

import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.*

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val n = readLine().toInt()
    val matrix = Array(n) {
        val st = StringTokenizer(readLine())
        IntArray(n) { st.nextToken().toInt() }
    }

    val countArray = IntArray(3) // -1, 0, 1을 인덱스 0,1,2로 카운트
    countPapers(matrix, 0, 0, n, countArray)
    val answer = StringBuilder().append(countArray[0]).appendLine()
        .append(countArray[1]).appendLine()
        .append(countArray[2])
    print(answer)
}

// 인자: (행렬, 시작 행, 시작 열, 길이, 카운트 배열)
fun countPapers(matrix: Array<IntArray>, row: Int, col: Int, size: Int, countArray: IntArray) {
    val result = hasSame(matrix, row, col, size) // 같은 숫자만 갖고 있으면 해당 숫자를, 아니면 2를 반환

    // 같은 숫자만 갖고 있으면, 카운트 하고 분할 종료
    if (result != 2) {
        countArray[result + 1]++
        return
    }

    val dividedSize = size / 3
    // 0행 분할
    countPapers(matrix, row, col, dividedSize, countArray) // 0행 0열 (왼쪽 위 행렬)
    countPapers(matrix, row, col + dividedSize, dividedSize, countArray) // 0행 1열 (중앙 위)
    countPapers(matrix, row, col + 2 * dividedSize, dividedSize, countArray) // 0행 2열 (오른쪽 위)

    // 1행
    countPapers(matrix, row + dividedSize, col, dividedSize, countArray)
    countPapers(matrix, row + dividedSize, col + dividedSize, dividedSize, countArray)
    countPapers(matrix, row + dividedSize, col + 2 * dividedSize, dividedSize, countArray)

    // 2행
    countPapers(matrix, row + 2 * dividedSize, col, dividedSize, countArray)
    countPapers(matrix,row + 2 * dividedSize, col + dividedSize, dividedSize, countArray)
    countPapers(matrix, row + 2 * dividedSize, col + 2 * dividedSize, dividedSize, countArray)
}

fun hasSame(matrix: Array<IntArray>, startRow: Int, startCol: Int, size: Int): Int {
    val prevPaper = matrix[startRow][startCol]

    for (row in startRow until startRow + size) {
        for (col in startCol until startCol + size) {
            if (prevPaper != matrix[row][col]) {
                return 2 // 다른 숫자를 갖고 있으면 2를 반환한다
            }
        }
    }

    return prevPaper // 같은 숫자만 갖고 있으면, 해당 숫자 반환
}

 

코드2 (Queue 사용)

import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.*

data class MatrixInfo(val row: Int, val col: Int, val size: Int)

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val n = readLine().toInt()
    val matrix = Array(n) {
        val st = StringTokenizer(readLine())
        IntArray(n) { st.nextToken().toInt() }
    }

    print(countPapers(matrix))
}

fun countPapers(matrix: Array<IntArray>): StringBuilder {
    val queue: Queue<MatrixInfo> = LinkedList()
    queue.offer(MatrixInfo(0, 0, matrix.size))

    val count = IntArray(3)

    while (queue.isNotEmpty()) {
        val matrixInfo = queue.poll()
        val result = hasSame(matrix, matrixInfo) // -1,0,1 -> 같은 숫자로 채워진 종이 , 2 -> 다른 숫자 포함

        if (result != 2) {
            count[result + 1]++
            continue
        }

        val row = matrixInfo.row
        val col = matrixInfo.col
        val dividedSize = matrixInfo.size / 3

        queue.offer(MatrixInfo(row, col, dividedSize)) // 0행 0열
        queue.offer(MatrixInfo(row, col + dividedSize, dividedSize)) // 0행 1열
        queue.offer(MatrixInfo(row, col + 2 * dividedSize, dividedSize)) // 0행 2열

        // 1행
        queue.offer(MatrixInfo(row + dividedSize, col, dividedSize))
        queue.offer(MatrixInfo(row + dividedSize, col + dividedSize, dividedSize))
        queue.offer(MatrixInfo(row + dividedSize, col + 2 * dividedSize, dividedSize))

        // 2행
        queue.offer(MatrixInfo(row + 2 * dividedSize, col, dividedSize))
        queue.offer(MatrixInfo(row + 2 * dividedSize, col + dividedSize, dividedSize))
        queue.offer(MatrixInfo(row + 2 * dividedSize, col + 2 * dividedSize, dividedSize))
    }

    return StringBuilder().append(count[0]).appendLine()
        .append(count[1]).appendLine()
        .append(count[2])
}

fun hasSame(matrix: Array<IntArray>, matrixInfo: MatrixInfo): Int {
    val prevPaper = matrix[matrixInfo.row][matrixInfo.col]

    for (row in matrixInfo.row until matrixInfo.row + matrixInfo.size) {
        for (col in matrixInfo.col until matrixInfo.col + matrixInfo.size) {
            if (prevPaper != matrix[row][col]) {
                return 2
            }
        }
    }

    return prevPaper
}

 

틀렸던 코드 (끝 점 2개로 행렬 구분) - 메모리 초과

import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.*

// 왼쪽 위, 오른쪽 아래의 점으로 행렬을 구분한다
data class Points(val r1: Int, val c1: Int, val r2: Int, val c2: Int)

fun main() = with(BufferedReader(InputStreamReader(System.`in`))){
    val n = readLine().toInt()
    val matrix = Array(n){
        val st = StringTokenizer(readLine())
        IntArray(n) { st.nextToken().toInt() }
    }

    print(countPapers(matrix))
}

fun countPapers(matrix: Array<IntArray>): StringBuilder{
    val queue: Queue<Points> = LinkedList()
    queue.offer(Points(0,0, matrix.size-1, matrix.size-1))

    val count = IntArray(3)

    while (queue.isNotEmpty()) {
        val points = queue.poll()
        val result = hasSame(matrix, points)

        if (result != 2){
            count[result + 1]++
            continue
        }

        val r1 = points.r1; val c1 = points.c1
        val r2 = points.r2; val c2 = points.c2
        val gapR1 = r1 + (r2 - r1) / 3; val gapR2 = gapR1 + (r2 - r1) / 3 + 1 // 1/3, 2/3 지점
        val gapC1 = c1 + (c2 - c1) / 3; val gapC2 = gapC1 + (c2 - c1) / 3 + 1

        queue.offer(Points(r1, c1, gapR1, gapC1)) // 0행 0열
        queue.offer(Points(r1, gapC1 + 1, gapR1, gapC2)) // 0행 1열
        queue.offer(Points(r1, gapC2 + 1, gapR1, c2)) // 0행 2열

        queue.offer(Points(gapR1 + 1, c1, gapR2, gapC1))
        queue.offer(Points(gapR1 + 1, gapC1 + 1, gapR2, gapC2))
        queue.offer(Points(gapR1 + 1, gapC2 + 1, gapR2, c2))

        queue.offer(Points(gapR2 + 1, c1, r2, gapC1))
        queue.offer(Points(gapR2 + 1, gapC1 + 1, r2, gapC2))
        queue.offer(Points(gapR2 + 1, gapC2 + 1, r2, c2))
    }

    return StringBuilder().append(count[0]).appendLine()
            .append(count[1]).appendLine()
            .append(count[2])
}

fun hasSame(matrix: Array<IntArray>, points: Points): Int {
    val paper = matrix[points.r1][points.c1]

    for (row in points.r1..points.r2) {
        for (col in points.c1..points.c2){
            if (paper != matrix[row][col]) return 2
        }
    }

    return paper
}

https://www.acmicpc.net/problem/1654

 

1654번: 랜선 자르기

첫째 줄에는 오영식이 이미 가지고 있는 랜선의 개수 K, 그리고 필요한 랜선의 개수 N이 입력된다. K는 1이상 10,000이하의 정수이고, N은 1이상 1,000,000이하의 정수이다. 그리고 항상 K ≦ N 이다. 그

www.acmicpc.net

풀이

길이가 제각각인 k개의 랜선을 일정한 크기 단위로 잘라서 n개 이상을 만들어 내려고 할 때, 그 크기 단위의 최대 값을 구하는 문제였다.

 

주어진 k개의 랜선으로 n개를 만들지 못하는 경우는 없다고 했으니까, 최소 1cm씩 잘라내면 항상 n개 이상을 구할 수 있다. 반대로 k = n = 1 이라면, 주어진 랜선 중 가장 긴 길이 만큼 자를 수 있다.

 

랜선의 길이가 21억cm을 넘기 때문에 1cm부터 차례대로 잘라보면 시간 초과가 날 것이다. 특정 크기로 잘라서 나오는 개수와 n을 비교해서 더 길게 잘라볼지, 짧게 잘라야 하는지 결정할 수 있으므로 이진 탐색을 사용할 수 있다.

 

시간 복잡도:  O(log(2^31)) = 약 31

 

구현 방법

  1. 랜선을 입력 받는다.
  2. 최소 길이(m) = 1, 최대 길이(M) = 가장 긴 랜선(or 2^31 - 1)으로 시작해서, 이진 탐색으로 랜선을 잘라보며 정답을 구한다.
    1. 각 케이블들을 길이 (m + M) / 2로 자른 개수의 합(count)을 구한다
    2. if (count >= n) -> 더 길게 잘라도 n개가 넘는지 확인해본다 -> m = (m + M) / 2 + 1
    3. else -> 너무 길게 잘라서 모자르다 -> M = (m + M) / 2 - 1
    4. 위 과정을 m <= M일 때까지 반복한다.
  3. 이진탐색이 끝났을 때, M 값이 답이 된다.

 

코드 1 (구현 방법대로 m <= M일 때까지 탐색하는 방식)

import java.io.BufferedReader
import java.io.InputStreamReader

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val input = readLine().split(" ")
    val k = input[0].toInt()
    val n = input[1].toInt()
    val cables = LongArray(k + 1) // 랜선

    repeat(k) {
        cables[it] = readLine().toLong()
    }

    print(binarySearch(cables, n))
}

fun cutCables(cables: LongArray, divisor: Long): Int {
    var count = 0L

    cables.forEach { cable ->
        count += cable / divisor
    }

    return count.toInt()
}

fun binarySearch(cables: LongArray, minimumCount: Int): Long {
    var answer = 0L
    var lower = 1L
    var upper = Integer.MAX_VALUE.toLong()

    while (lower <= upper) {
        val mid = (lower + upper) / 2  // 이 과정에서 Long 값이 나올 수 있다
        val numberOfCut = cutCables(cables, mid)

        if (numberOfCut < minimumCount) {
            upper = mid - 1
        } else {
            lower = mid + 1
            answer = mid
        }
    }

    return answer
}

 

코드 2 (m < M일 때까지 탐색하는 방식)

import java.io.BufferedReader
import java.io.InputStreamReader

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val input = readLine().split(" ")
    val k = input[0].toInt()
    val n = input[1].toInt()
    val cables = IntArray(k + 1) // 랜선

    repeat(k) {
        cables[it] = readLine().toInt()
    }

    print(binarySearch(cables, n))
}

fun cutCables(cables: IntArray, divisor: Long): Long {
    var count = 0L

    cables.forEach { cable ->
        count += cable / divisor
    }

    return count
}

fun binarySearch(cables: IntArray, minimumCount: Int): Long {
    var answer = 0L
    var lower = 0L
    var upper = Integer.MAX_VALUE.toLong()
    upper += 1 // while(lower < upper) 방식에서는 최초의 upper가 답인 경우를 구할 수 없어서 1을 더해준다
    
    while (lower < upper) {
        val mid = (lower + upper) / 2
        val numberOfCut = cutCables(cables, mid)

        if (numberOfCut < minimumCount) {
            upper = mid
        } else {
            lower = mid + 1
            answer = mid
        }
    }

    return answer
}

https://www.acmicpc.net/problem/1167

 

1167번: 트리의 지름

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지

www.acmicpc.net

 

풀이

트리에서 임의의 두 노드 거리 중 가장 긴 것을 구하는 문제이지만 예시랑 설명이 부족하다. 1967번(https://www.acmicpc.net/problem/1967)이랑 비슷해서 그런 거 같은데, 이 문제의 예시를 살펴보는 걸 추천한다!

 

트리의 지름을 구하는 문제는 공식처럼 증명된 방법(https://blog.myungwoo.kr/112)이 있는데 그 외의 방법을 이야기 해보려고 한다. 

 

임의의 두 노드 사이의 거리를 구하는 방법을 찾아내기 위해, 특정 노드를 포함하는 최장 거리를 살펴보면 다음과 같다.

 

 

노드1을 포함하는 가장 긴 경로는 8-4-2-(1)-3-5-9로, 7+5+3+2+11+15 = 43이 된다. 그럼 전체 트리의 지름은 무엇일까?

 

 

9-5-(3)-5-12 -> 15+11+9+10 = 45로, 루트1을 포함하지 않는 것을 알 수 있다. 그럼, 트리의 지름과 루트 노드는 관련이 없는 것 같은데, 사실은 매우 관련이 있다.

 

1967번 문제 설명에서 알 수 있듯이, 트리의 지름은 경로를 포함하는 양 끝점을 양방향으로 잡아당겼을 때의 가장 긴 거리라고도 생각할 수 있다. 이것을 루트 노드와 결합해서 생각해보면 다음과 같다.

 

"각 점을 루트로 하는 서브 트리에서, 각 자식 노드에서 leaf 노드까지 가는 경로 중 가장 긴 경로의 합"

 

그림과 함께 설명하면 다음과 같다.

 

 

노드1을 루트로 하는 트리에서, 각 자식 노드(노드2, 노드3)에서 leaf 노드로 가는 경로는 6개이다.

1 -> 2~7

1 -> 2~8

1 -> 3~9

1 -> 3~10

1 -> 3~11

1 -> 3~12

 

여기서 2에서 가는 경로 중 최장 거리 하나와 3에서 가는 최장 거리를 하나씩 선택하면, 양 방향 잡아당겼을 때 최장 거리가 된다.

 

그리고 이것을 임의로 노드의 지름이라고 하면, 각 노드그의 지름 중 최대 값이 전체 트리의 지름이 된다. 왜냐하면 트리는 사실 어떤 노드도 루트가 될 수 있기 때문이다. 

어떤 노드로 루트가 될 수 있다

 

그래서 위 그림에서의 전체 지름도 (3 -> 1~), (3 -> 5~), (3 -> 6 ~)의 경로 중 가장 긴 2개를 더한 값이 답이 된 것이다.

 

구현 방법 (DFS)

1. 트리를 입력받는다

2. 루트 -> 자식 하나의 최장 거리를 반환하는 함수를 재귀로 작성한다.

2-1) 재귀 함수 안에서 루트의 인접 노드를 대상으로 반복문을 돌린다.

2-2) 반복문 안에서 최장 거리를 구할 때마다 최대 값과 두 번째 최대 값을 갱신한다.

2-3) 모든 인접 노드를 탐색 후, 리턴하기 전에 지금까지 구한 지름과 (최대 값 + 두 번째 최대 값)을 비교해서 갱신한다.

 

 

참고한 곳 & 후기)

https://kimyunseok.tistory.com/125

 

개인적으로 재귀 구현이 어렵게 느껴질 때가 종종 있었는데, 재귀를 공부하는 데 많은 도움이 된 것 같다. 

 

코드 1

import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.*
import kotlin.math.max

data class Edge(var node: Int, var distance: Int)

var answer = 0

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val n = readLine().toInt()
    val tree = Array(n+1){ mutableListOf<Edge>() }

    // 트리 입력
    repeat(n){
        val st = StringTokenizer(readLine())
        val node = st.nextToken().toInt()
        var adjacentNode = st.nextToken().toInt()
        while(adjacentNode != -1){
            val distance = st.nextToken().toInt()
            tree[node].add(Edge(adjacentNode, distance))
            tree[adjacentNode].add(Edge(node, distance))

            adjacentNode = st.nextToken().toInt()
        }
    }

    val visited = BooleanArray(n+1).apply{ this[1] = true }
    getMaxLengthFrom(tree, 1, visited) // 지름을 찾는다
    print(answer)
}

// 루트에서 리프 노드까지 중 최장 거리를 구한다
fun getMaxLengthFrom(tree: Array<MutableList<Edge>>, rooteNode: Int, visited: BooleanArray): Int {
    val edgeList = tree[rooteNode]
    var firstMax = 0  // 리프 노드까지 가는 가장 긴 거리
    var secondMax = 0 // 2번째로 긴 거리

    // 인접 노드 리스트
    for (i in edgeList.indices){
        val adjacentNode = edgeList[i].node
        if (visited[adjacentNode]) {
            continue
        }

        visited[adjacentNode] = true
        
        // 인접 자식까지 거리 + 자식에서 손자까지 최장거리
        val maxLength = getMaxLengthFrom(tree, adjacentNode, visited) + edgeList[i].distance
        
        // 최대 값 갱신
        if (firstMax < maxLength){
            secondMax = firstMax
            firstMax = maxLength
        }else if (secondMax < maxLength){
            secondMax = maxLength
        }
    }

    // 각 노드를 루트로하는 (가장 긴 거리 + 2번째로 가장 긴 거리)가 지름의 후보가 된다
    answer = max(answer, firstMax + secondMax) // answer 값 갱신

    return firstMax
}

 

코드 2 (공식)

import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.*

data class Node(var node: Int, var distance: Int)

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val n = readLine().toInt()
    val tree = Array(n+1) { mutableListOf<Node>() }

    initTree(this, tree, n) // 트리 입력

    val distance = IntArray(n+1){ -1 }.apply{ this[1] = 0} // 노드1과의 거리 + 중복 방문 체크 역할
    findDiameter(distance, tree, 1) // 노드 1의 지름을 구한다

    var tempDiameterNode = 0 // 노드1의 최장 노드
    var tempMaxDistance = 0  // 노드1에서의 최장 거리
    for (i in 1..n) {
        if (tempMaxDistance < distance[i]){
            tempMaxDistance = distance[i]
            tempDiameterNode = i
        }
    }

    // 거리 초기화
    distance.apply {
        for (i in 1..n) this[i] = -1
        this[tempDiameterNode] = 0
    }
    findDiameter(distance, tree, tempDiameterNode) // 트리 지름 탐색
    print(distance.maxOrNull())
}

fun initTree(br: BufferedReader, tree: Array<MutableList<Node>>, size: Int){
    repeat(size) {
        val st = StringTokenizer(br.readLine())
        val node = st.nextToken().toInt()
        var toNode = st.nextToken().toInt() // 인접 노드

        while (toNode != -1) {
            val distance = st.nextToken().toInt()
            tree[node].add(Node(toNode, distance)) // 인접 리스트에 추가

            toNode = st.nextToken().toInt() // 인접 노드 or -1
        }
    }
}

fun findDiameter(distance: IntArray, tree: Array<MutableList<Node>>, node: Int) {
    val list = tree[node]

    list.forEach { nextNode ->
        if (distance[nextNode.node] == -1){
            distance[nextNode.node] = distance[node] + nextNode.distance
            findDiameter(distance, tree, nextNode.node)
        }
    }
}

+ Recent posts