三浦と窮理とブログ

自分の経験したことを検索可能にしていくブログ.誰かの役に立ってくれれば嬉しいです.

Javaで末尾再帰最適化をする方法

StreamAPIを使えばjavaでも末尾再帰最適化(Tail-Call Optimization)ができるぞという本( Javaによる関数型プログラミング ―Java 8ラムダ式とStream | Venkat Subramaniam, 株式会社プログラミングシステム社 |本 | 通販 | Amazon 。英語版のpdfが無料でネットからダウンロードできます)があったので、自分でも読みながら実装してみます。

本に書いてある内容はインターフェース(IF)も使って抽象的で良いのですが、ラムダ式初心者の私には一目で理解できるものでは無かったので具体的なことから考えていきます。

目次

1からnまでの自然数の和

よくある練習問題の、1からnまでの自然数の和を求める関数を実装します。

ループによる実装

最初はループで普通に実装してみます。

public static int sumLoop(int n) {
    int sum = 0;
    for (int i = 1; i <= n; i++) {
        sum += i;
    }
    return sum;
}

実行例

System.out.println(sumLoop(10)); // 55

関数型プログラミングを勉強していると、変数の再代入のような副作用は無くしたくなる気持ちが湧いてきます。そこでよく使われる書き方として再帰があります。

再帰による直感的な実装

すぐに思いつく再帰による実装は以下のようなものです。

public static int sumRecursion(int n) {
    if (n == 1) return 1;

    return n + sumRecursion(n - 1);
}

実行例

System.out.println(sumRecursion(10)); // 55

この関数では最後にnを足してreturnする前に、sumRecursion自身を実行してその戻り値が返ってくるのを待つため、実行元の関数の情報をスタックする必要があります。

そのため、このままでは大きなnのときにスタックオーバーフローが起きてしまいます。スタックオーバーフローが起きるくらいのnだとintの桁あふれも起きるのでほんとはBigIntegerを使わないと正しい値が得られないですが、今回はその実装は省きます。

末尾再帰による実装

再帰を使うけど実行元の関数をスタックする必要性のない関数を作りたいです。そのためには、この処理に必要な情報をすべて次の再帰呼び出しされる関数に渡せばいいと思います。 継続渡しスタイル - Wikipedia

以下のように書けるのではないでしょうか。

public static int sumTailRecursion(int n, int sum) {
    if (n == 0) return sum;

    return sumTailRecursion(n - 1, sum + n);
}

実行例

System.out.println(sumTailRecursion(10, 0)); // 55

この関数の引数には、次に足す予定の自然数と今まで足してきた数を渡します。

このように関数の最終処理が純粋に自身の関数を呼ぶだけにする書き方を末尾まつび再帰と呼びます。

この実装では再帰実行した後にはもう実行元の関数の情報は必要はないのですが、残念ながらJVMは実行元関数をスタックするようにできてしまっていて、この実装でも大きいnにおいてスタックオーバーフローが起きます。

言語によってはコンパイラが末尾再帰最適化というものを行ってくれるものがあり、そのような言語では末尾再帰で実装した関数ではスタックオーバーフローは起きないのですが、Javaでは対応していないのですね。

再帰的DTO

末尾再帰最適化の実装で本を読みながら最初に私が考えたものは、もはや末尾再帰と言えるものか分からなくなったので、ここでは個人的に本質的だなと思った概念に再帰的DTO(Data Transfer Object)という名前を付けて説明していきたいと思います。

再帰的DTOを使った実装は以下です。(Lombok使ってます。)

import lombok.AllArgsConstructor;
import lombok.Getter;

@AllArgsConstructor
@Getter
public class SumDTO {  // 再帰的DTO
    final private int number;
    final private int sum;

    // 自身の型を返す関数。
    public SumDTO getNextSumDTO() {
        return new SumDTO(number - 1, sum + number);
    }
}
public static int sumRecursionDTO(int n) {
    return Stream.iterate(new SumDTO(n, 0), SumDTO::getNextSumDTO) // Stream<SumDTO> の生成
            .filter(sumDto -> sumDto.getNumber() == 0)
            .map(SumDto::getSum)
            .findFirst()
            .get();
}

実行例

System.out.println(sumRecursionDTO(10)); // 55

SumDTOクラスが再帰的DTOです。足す予定の数numberと今まで足してきた数sumを定数フィールドに持ち、次に足す予定の数number - 1と今足した数sum + numberを渡したSumDTOインスタンスを返すgetNextSumDTO()メソッドを定義しています。

sumRecursionDTOメソッドでは、SumDTO#getNextSumDTOメソッドを用いてStream<SumDTO>を生成し、SumDTO#numberフィールドを見てfilterして足し算の最後の値を取得します。

ストリームパイプライン処理では最初の要素から1つごとに、生成操作、中間操作、終端操作までを一度にしますので、スタックするものはないです。

スタックオーバーフローが起きない実装が書けました。

再帰的関数型インターフェースと再帰的ラムダ式

上の実装では再帰的DTOにあるように、定義してるクラス自身の型を返すメソッドを使ってストリームを生成しました。この機能を関数型IFとラムダ式を使って書き換えていくことを考えます。

import java.util.stream.Stream;

// 再帰的関数型インターフェース
@FunctionalInterface
public interface RecursionSupplier<T> {
    RecursionSupplier<T> get(); // 定義しているIF自身の型を返すSAM

    // 再帰の最後に生成される要素の取得
     static <T> RecursionSupplier<T> getLast(final T value) {
        return new RecursionSupplier<T>() {
            @Override
            public RecursionSupplier<T> get() {
                throw new Error("not implemented");
            }

            @Override
            public boolean isComplete() {
                return true;
            }

            @Override
            public T result() {
                return value;
            }
        };
    }

    // 再帰の終了判定用メソッド
    default boolean isComplete() {
        return false;
    }

    // 再帰の終端要素の値取得
    default T result() {
        throw new Error("not implemented");
    }

    default T invoke() {
        return Stream.iterate(this, RecursionSupplier::get) // Stream<RecursionSupplier<Integer>>
                .filter(RecursionSupplier::isComplete)
                .findFirst()
                .get()
                .result();
    }
}
public static RecursionSupplier<Integer> sumTCO(int n, int sum) {
    if (n == 0) {
        return RecursionSupplier.getLast(sum);
    }
    return () -> sumTCO(n - 1, sum + n); // 再帰的ラムダ式
}

実行例

System.out.println(sumTCO(10, 0).invoke()); // 55

RecursionSupplier<T>という関数型IFを作りました。このIFの単一抽象メソッド(SAM)は引数を持たずRecursionSupplier<T>型インスタンスを返すメソッドです。

このIFとJavaのAPIに用意されている関数型IFのSupplier<T>との違いで重要なのはSAMの返り値です。Supplier<T>のSAMはT型インスタンスを返します。

この実装では定数フィールドすら持たず、再帰の終了判定をするためのメソッドisCompleteを用意します。

型について再帰的なSAMを持つIFを用意することで、sumTCOメソッドのように再帰的なラムダ式を書くことができます。

実行時に最初に一度だけinvokeメソッドが呼ばれ、Stream<RecursionSupplier<Integer>>を生成してパイプライン処理をしていきます。

まとめ

末尾再帰や再帰的ラムダ式について学ぶことができました。ループによる実装と比べるととても大変ですが、副作用の無い関数はいろいろメリットもあるのでこれらの方法も覚えておきたいですね。

ラムダ式が読みにくい場合は一度無名クラスで書き下すと理解しやすくなりますね。