C++14でcurry化
関数型言語では関数が第一級の値という。私もよく分かっていないが、とりあえず何をするにも関数で、関数を操作するのだそうだ。その一つの例としてカリー化があげられ、ある関数に対して引数を一部与えて束縛し、その関数を返すというものである。C++で実装してみよう。
要するに、curry(関数)としてあげると、関数は「ある引数args1を取り〈残る引数args2を取る関数〉を返す関数」であると分かる。内部で関数を生成するためにジェネリックラムダを用いた。これはC++14になって導入されたもので、
auto l = [](auto a) {};
として使える。このラムダ式lにはどのような型も入れられる。つまりtemplateが使えると言うことである。
あとはC++11からあるお馴染みのキャプチャに関数を入れてあげれば無事curryのできあがり。
template <typename F> auto curry(F f) { return [f](auto... args) { return [f, args](auto... args2) { return f(args..., args2...); }; }; }
ただこれだと問題があって、例えば引数に参照が入っているような関数では使えない。それはテンプレートでも同じ問題で、
template <typename T> void func(T t) { }
とすると、t自体が参照ではなくなってしまうからだ。この問題を解決するために完全転送を行う。使い方自体は単純で、TをT&&にして、tをstd::forward
auto curry(F&& f) { return [&f](auto&&... args) { return [&f, &args](auto&&... args2) { return f(std::forward<decltype(args)>(args)..., std::forward<decltype(args2)>(args2)...); }; }; }
と思ったがこれでも完全ではない。というのも、元の関数がもしも参照を返すような関数の時は戻り値から参照が取られてしまう。従って、関数fの戻り値と一致させるためにdecltype(auto)を入れてあげる必要がある。autoだけでは参照が取れてしまうのだ。
auto curry(F&& f) { return [&f](auto&&... args) { return [&f, &args](auto&&... args2) -> decltype(auto) { return f(std::forward<decltype(args)>(args)..., std::forward<decltype(args2)>(args2)...); }; }; }
使い方
auto l = [](int a, int b) { return a + b; }; auto f = curry(l)(10); auto v = f(5);
fはカリー化された関数で、足し算を定義してあるl及び、それをカリー化して第1引数に10を入れたものである。fもまた関数で引数を1つ持つ関数である。従って、v = 10 + 5が実行される。releaseモードで実行すればこの程度のコードなら簡単に展開されてしまう。すなわち実質10+5が実行されるに過ぎない。このカリー化はtemplateを使うので最適化がかかりやすく、実行速度が出ると思われるのでもしも有用な使い方があれば積極的に使っていきたい。ただし実用性のほどはよく分からない。
C#のLINQライクな遅延評価ライブラリを作ったので、そのときのwhereやmapに入れる関数として引数を減らしつつ柔軟性を持たせるようなコードに使えるかも知れない。
余談
実は完全転送版とそうでないものだとVisual Studio 2017 RCでは最適化の効きが異なるようだ。アセンブリコードを見ると余分なコードがのっかって+50%遅くなる。原因は分からないが、改善されることを祈る。
int main() { auto startTime = std::chrono::system_clock::now(); const int n = 1000000000; int s = 0; for (int i = 0; i < n; ++i) { auto f = func_curry(i); auto v = f(i); s += v; } auto endTime = std::chrono::system_clock::now(); auto timeSpan = endTime - startTime; std::cout << s << std::endl; std::cout << "処理時間:" << std::chrono::duration_cast<std::chrono::milliseconds>(timeSpan).count() << "[ms]" << std::endl; auto l = [](int a, int b) { return a + b; }; startTime = std::chrono::system_clock::now(); s = 0; auto f = curry(l); for (int i = 0; i < n; ++i) { auto c = f(i); auto v = c(i); s += v; } endTime = std::chrono::system_clock::now(); timeSpan = endTime - startTime; std::cout << s << std::endl; std::cout << "処理時間:" << std::chrono::duration_cast<std::chrono::milliseconds>(timeSpan).count() << "[ms]" << std::endl; system("pause"); return 0; }