wrongwrongな開発日記

しんまいさんの忘備録

【プログラミング】末尾再帰を使う【Kotlin】

末尾再帰とは

この記事ではtailrec funで宣言すると末尾再帰最適化がかかる関数全てを末尾再帰であるとします。

なぜ末尾再帰が必要なのか

再帰的な書き方ではコードが美しくなる反面、実行時間が遅い、スタックオーバーフローを引き起こすなど、実用上の問題があります。
一方末尾再帰最適化を行えば、コンパイル結果はforやwhileを使うのと同じになるので、実用性を保ちながら再帰で書くことが可能になります。
また、単純な再帰で書くよりもシンプルに書くことができる場合もあります。
今回はKotlinを用いて説明をしていますが、言語やコンパイラによっては末尾最適化を行なってくれるものがあります。

もう少し詳しく

個人的理解

末尾の条件に引っかからなかった時に、再帰呼び出しの結果のみを返すのが末尾再帰再帰呼び出しの結果に何か処理を行なって返すのが通常の再帰と理解しています。

実装による解説

※雑実装なので1以下を入れると結果がおかしくなります。後結果が大きくなるのでBigIntegerで計算しています。
通常の再帰関数で階乗を実装すると以下のようになると思います。

val TWO = BigInteger("2")

fun fact(n: BigInteger): BigInteger{
    if(n.compareTo(TWO) < 1) return TWO
    return n * fact(n.dec())
}

一方、末尾再帰最適化が効くように書き直した結果が以下です。

val TWO = BigInteger("2")

tailrec fun fact_t(n: BigInteger, ans:BigInteger = BigInteger.ONE): BigInteger{
    if(n.compareTo(TWO) < 1) return ans * TWO
    return fact_t(n.dec(), ans * n)
}

前者は再帰呼び出しの結果にnを掛けており、後者は再帰呼び出しの結果のみを返しています。

使ってみる

ここまでの内容をまとめると以下のようになります。
通常の実装で10000!を計算しようとすればスタックオーバーフローで落ちますが、末尾最適化を行うことで正常に計算が行えます。

import java.math.BigInteger

val TWO = BigInteger("2")

fun fact(n: BigInteger): BigInteger{
    if(n.compareTo(TWO) < 1) return TWO
    return n * fact(n.dec())
}

tailrec fun fact_t(n: BigInteger, ans:BigInteger = BigInteger.ONE): BigInteger{
    if(n.compareTo(TWO) < 1) return ans * TWO
    return fact_t(n.dec(), ans * n)
}

fun main(args: Array<String>) {
    //println(fact(BigInteger("10000"))) java.lang.StackOverflowError
    println(fact_t(BigInteger("10000")))
}
デコンパイル結果

通常実装のfactのコメントを外した上でビルドし、Javaデコンパイルした結果です。
fact_tの呼び出しが再帰ではなくなっていることが分かります。

import java.math.BigInteger;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

@Metadata(
   mv = {1, 1, 13},
   bv = {1, 0, 3},
   k = 2,
   d1 = {"\u0000\u001c\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010\u0011\n\u0002\u0010\u000e\n\u0002\b\u0002\u001a\u000e\u0010\u0004\u001a\u00020\u00012\u0006\u0010\u0005\u001a\u00020\u0001\u001a\u001b\u0010\u0006\u001a\u00020\u00012\u0006\u0010\u0005\u001a\u00020\u00012\b\b\u0002\u0010\u0007\u001a\u00020\u0001H\u0086\u0010\u001a\u0019\u0010\b\u001a\u00020\t2\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\f0\u000b¢\u0006\u0002\u0010\r\"\u0011\u0010\u0000\u001a\u00020\u0001¢\u0006\b\n\u0000\u001a\u0004\b\u0002\u0010\u0003¨\u0006\u000e"},
   d2 = {"TWO", "Ljava/math/BigInteger;", "getTWO", "()Ljava/math/BigInteger;", "fact", "n", "fact_t", "ans", "main", "", "args", "", "", "([Ljava/lang/String;)V", "FactTest"}
)
public final class FactKt {
   @NotNull
   private static final BigInteger TWO = new BigInteger("2");

   @NotNull
   public static final BigInteger getTWO() {
      return TWO;
   }

   @NotNull
   public static final BigInteger fact(@NotNull BigInteger n) {
      Intrinsics.checkParameterIsNotNull(n, "n");
      if (n.compareTo(TWO) < 1) {
         return TWO;
      } else {
         BigInteger var10000 = n.subtract(BigInteger.ONE);
         Intrinsics.checkExpressionValueIsNotNull(var10000, "this.subtract(BigInteger.ONE)");
         BigInteger var2 = fact(var10000);
         var10000 = n.multiply(var2);
         Intrinsics.checkExpressionValueIsNotNull(var10000, "this.multiply(other)");
         return var10000;
      }
   }

   @NotNull
   public static final BigInteger fact_t(@NotNull BigInteger n, @NotNull BigInteger ans) {
      while(true) {
         Intrinsics.checkParameterIsNotNull(n, "n");
         Intrinsics.checkParameterIsNotNull(ans, "ans");
         BigInteger var10000;
         if (n.compareTo(TWO) < 1) {
            BigInteger var3 = TWO;
            var10000 = ans.multiply(var3);
            Intrinsics.checkExpressionValueIsNotNull(var10000, "this.multiply(other)");
            return var10000;
         }

         var10000 = n.subtract(BigInteger.ONE);
         Intrinsics.checkExpressionValueIsNotNull(var10000, "this.subtract(BigInteger.ONE)");
         BigInteger var4 = var10000;
         var10000 = ans.multiply(n);
         Intrinsics.checkExpressionValueIsNotNull(var10000, "this.multiply(other)");
         BigInteger var5 = var10000;
         ans = var5;
         n = var4;
      }
   }

   // $FF: synthetic method
   @NotNull
   public static BigInteger fact_t$default(BigInteger var0, BigInteger var1, int var2, Object var3) {
      if ((var2 & 2) != 0) {
         BigInteger var10000 = BigInteger.ONE;
         Intrinsics.checkExpressionValueIsNotNull(var10000, "BigInteger.ONE");
         var1 = var10000;
      }

      return fact_t(var0, var1);
   }

   public static final void main(@NotNull String[] args) {
      Intrinsics.checkParameterIsNotNull(args, "args");
      BigInteger var1 = fact(new BigInteger("10000"));
      System.out.println(var1);
      var1 = fact_t$default(new BigInteger("10000"), (BigInteger)null, 2, (Object)null);
      System.out.println(var1);
   }
}