monolithic kernel

Go の WaitGroup と Java の Phaser

Go の WaitGroup は、複数のスレッドに処理を振り分けて、すべての処理が完了するまで待つような用途で便利。Java で Go の WaitGroup に似たものとして、標準ライブラリに CountDownLatch がある。

package main

import (
	"fmt"
	"sync"
)

func main() {
	type Task struct {
		Id int
	}

	tasks := []Task{
		Task{1},
		Task{2},
		Task{3},
	}

	var wg sync.WaitGroup
	wg.Add(len(tasks))
	for _, task := range tasks {
		go func(task Task) {
			defer wg.Done()
			fmt.Printf("Task.Id=%v\n", task.Id)
		}(task)
	}
	wg.Wait()
}

(コード例は Kotlin)

import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors

fun main() {
  data class Task(
    val id: Int,
  )

  val tasks = listOf(
    Task(1),
    Task(2),
    Task(3),
  )

  val executorService = Executors.newCachedThreadPool()

  val countDownLatch = CountDownLatch(tasks.size)
  for (task in tasks) {
    executorService.execute {
      try {
        println("Task.id=${task.id}")
      } finally {
        countDownLatch.countDown()
      }
    }
  }
  countDownLatch.await()
}

もっと簡単に Parallel Stream でもいける。

fun main() {
  data class Task(
    val id: Int,
  )

  val tasks = listOf(
    Task(1),
    Task(2),
    Task(3),
  )

  tasks.stream().parallel().forEach { task ->
    println("Task.id=${task.id}")
  }
}

しかし、Java におけるいずれのやり方も、以下の WaitGroup の利用ケースのように、処理すべき要素の数が未知である場合には対応できない。Parallel Stream は一見できそうに見えるが、内部で ForkJoinPool というクラスを使っている都合上、事前に要素数がわかっている必要があるようだった。

package main

import (
	"fmt"
	"sync"
)

func main() {
	type Task struct {
		Id int
	}

	tasks := make(chan Task)
	go func() {
		defer close(tasks)

		// 実際には Task の数は可変
		tasks <- Task{1}
		tasks <- Task{2}
		tasks <- Task{3}
	}()

	var wg sync.WaitGroup
	for task := range tasks {
		wg.Add(1)
		go func(task Task) {
			defer wg.Done()
			fmt.Printf("Task.Id=%v\n", task.Id)
		}(task)
	}
	wg.Wait()
}

例だと tasks は3要素なので、先にチャネルからすべての要素を受け取ってしまえばよいと思うかもしれない。しかし、ここにメモリに乗り切らないような量で要素数も未知のデータが流れてくる場合には、順次読み進めながら task の処理をディスパッチしていくようにしたい。

調べてみたところ、このような場合に使えるものとして Phaser というクラスがあることを知った。

import java.util.concurrent.Executors
import java.util.concurrent.Phaser

fun main() {
  data class Task(
    val id: Int,
  )

  val tasks = sequence {
    // 実際には Task の数は可変
    yield(Task(1))
    yield(Task(2))
    yield(Task(3))
  }

  val executorService = Executors.newCachedThreadPool()

  val phaser = Phaser(1)
  for (task in tasks) {
    phaser.register()
    executorService.execute {
      try {
        println("Task.id=${task.id}")
      } finally {
        phaser.arrive()
      }
    }
  }
  phaser.arriveAndAwaitAdvance()
}

registerAddarriveDonearriveAndAwaitAdvanceWait にそれぞれ対応しており、Go で WaitGroup を使った場合と似たような使い勝手を実現できている。Java 的にはもうちょっとラップしてあげるほうがらしい気はするが。


Related articles