Erlang: Funs (Part 1).

In the previous article we learned what predicates are. You could define a predicate and combine it with lists:filter to keep only the elements of a list that satisfy it. That works, but in most cases it's tedious and verbose to define a predicate as a regular function: often you just need a tiny piece of code that's used only once across your whole project.

This is where Funs come in. Let me show you an example.

1> lists:filter(fun(X) -> X > 5 end, [4, 5, 6, 7, 8]).
[6,7,8]

Here, instead of defining the predicate somewhere else, we pass the logic directly to the filter function as a Fun. The syntax of Funs is similar to a regular function declaration: you can have clauses and guards, and define local variables. The differences are that you don't specify the function's name (though it's possible, and it's used for declaring recursive Funs; I'll show you an example in a moment), and that you use the keywords fun and end only once (not per clause!), at the very beginning and at the very end, respectively.
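To make the clause syntax concrete, here is a small sketch of a multi-clause Fun with a guard. The module name fun_clauses and the function classify_all/1 are hypothetical, made up only for this illustration; save it as fun_clauses.erl and compile it with c(fun_clauses) in the shell.

```erlang
-module(fun_clauses).
-export([classify_all/1]).

%% A multi-clause Fun: `fun` and `end` appear only once, the clauses are
%% separated by `;`, and each clause may have its own guard.
classify_all(L) ->
    Classify = fun(X) when X < 0 -> negative;
                  (0) -> zero;
                  (_) -> positive
               end,
    lists:map(Classify, L).
```

Calling fun_clauses:classify_all([-3, 0, 7]) returns [negative, zero, positive].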

Just as with function references, we can assign Funs to variables and call them later. Let's try it out.

1> Double = fun(X) -> X * 2 end.
2> Double(5).

It's possible to give a fun a name so that you can implement recursive algorithms (named funs are available since Erlang/OTP 17). However, the name can't be an atom in this case; instead, it looks like a variable. Here is an example.

1> F = fun Fib(0) -> 0; Fib(1) -> 1; Fib(N) -> Fib(N - 1) + Fib(N - 2) end.
2> F(2).
3> F(3).
4> F(15).

This fun returns a Fibonacci number; for the calls above it returns 1, 2 and 610, respectively. The only argument is the position of the number in the Fibonacci sequence (note that counting starts from 0). The first two positions (0 and 1) are base cases, while every other position is broken down recursively into smaller subproblems: the Fibonacci number at a given position is the sum of the two previous Fibonacci numbers.

Besides being a good example of recursion in Funs, this example brings in a new concept: multiple recursion. The general clause calls Fib twice, so a single call branches out into a whole tree of calls.
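To get a feel for how quickly multiple recursion branches out, the following sketch counts how many times a naive Fib-like fun is invoked for a given position. The module fib_calls and its calls/1 function are hypothetical, written only for this illustration; it uses a named fun, just like the example above.

```erlang
-module(fib_calls).
-export([calls/1]).

%% Counts how many times the naive Fib fun would be invoked for position N:
%% one call for N itself plus every call made by the two recursive branches.
calls(N) ->
    Calls = fun C(0) -> 1;
                C(1) -> 1;
                C(K) -> 1 + C(K - 1) + C(K - 2)
            end,
    Calls(N).
```

For position 15 it reports 1973 calls, even though only 16 distinct positions (0 to 15) exist; most of the work is recomputing the same values.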

Note how we gave it a variable-like name, Fib. This variable is a reference to the fun itself, which can then be used like any other function reference. Let's implement a nullary function that returns a reference to itself.

1> F = fun Test() -> Test end.
2> F().

Some languages don't let you refer to a function from within its own body. In that case you can use an abstraction called the Y combinator, discovered by Haskell Curry.

1> Fib = fun(_, 0) -> 0; (_, 1) -> 1; (F, N) -> F(F, N - 1) + F(F, N - 2) end.
2> Fib(Fib, 6).

Strictly speaking, this is not the Y combinator itself (which is usually implemented as a higher-order function called fix), but it is built on the same idea: self-application. In the example above the fun has no name, so it can't refer to its own body directly; the only option is to receive a reference to itself as an argument, F in our case. Once the fun is created, it is bound to the variable Fib. To use it, we call Fib and pass Fib itself as the very first argument; Fib(Fib, 6) returns 8.
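For comparison, here is a sketch of a fix-style combinator in Erlang. Because Erlang evaluates arguments eagerly, the classic Y combinator would loop forever, so this uses the strict variant (often called the Z combinator), which delays the self-application behind an extra fun. The module fix_demo and its functions are hypothetical names chosen for this illustration.

```erlang
-module(fix_demo).
-export([fix/1, fib/1]).

%% fix/1: a strict fixed-point combinator. X(X) is wrapped in
%% fun(V) -> ... end so it is only evaluated when the function is applied.
fix(F) ->
    G = fun(X) -> F(fun(V) -> (X(X))(V) end) end,
    G(G).

%% Fibonacci without a named fun: the body receives Self from fix/1.
fib(N) ->
    Fib = fix(fun(Self) ->
                  fun(0) -> 0;
                     (1) -> 1;
                     (K) -> Self(K - 1) + Self(K - 2)
                  end
              end),
    Fib(N).
```

Here the fun never passes itself around explicitly; fix/1 takes care of the self-application, which is what distinguishes it from the Fib(Fib, 6) trick above.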

Funs can be used in a variety of scenarios, and they are the key to some new and unexpected abstractions. In fact, Funs (also known as lambda functions or anonymous functions) open up a whole new world that's huge and powerful.

You can use Funs to create your own control flow abstractions (like if / case / switch / for-loop), or even switch from eager to lazy evaluation and more; this will be discussed in part 3 of this series. You can use them for optimizations: for instance, you can do some work, like reading settings, once during initialization, assemble a fun right there with the unnecessary parts removed, and hand that fun to a long-running process, potentially saving a lot of time on evaluations that would otherwise be repeated on every call. And finally, Funs are data. Yes, that's right: Funs are actually data. By that I mean Funs can be thought of not only as a fundamental type in Erlang, but also as a way to create new types. This is an advanced topic and the subject of part 2 of this series.
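The optimization idea above can be sketched roughly like this: the settings are inspected once, when the fun is built, and the returned fun no longer has to check them on every call. The module specialize_demo, the function make_scaler/1 and the shape of the settings map are all assumptions made up for this example (maps require Erlang/OTP 17 or later).

```erlang
-module(specialize_demo).
-export([make_scaler/1]).

%% Hypothetical settings map: #{mode => identity} or
%% #{mode => scale, factor => K}. The settings are matched only once, here,
%% and each clause returns a fun containing just the work it needs.
make_scaler(#{mode := identity}) ->
    fun(X) -> X end;
make_scaler(#{mode := scale, factor := K}) ->
    fun(X) -> X * K end.
```

A long-running process could then receive Scale = specialize_demo:make_scaler(#{mode => scale, factor => 3}) and call Scale(5) to get 15 without ever looking at the settings again.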

Now I would like to introduce three very common list functions: filter, map and fold (in fact, folds). We have already seen filter, so let's play with map and fold and see how they're connected to each other.

Map is a function that transforms a given list into a list of the same length where each element X is replaced by F(X). It might sound a little confusing, but the function is much simpler when you see it in action. Let's implement it first.


map(_, []) -> [];
map(F, [H|T]) -> [F(H)|map(F, T)].

Let's test it.

1> lesson5:map(fun(X) -> X + 1 end, [1, 2, 3]).
[2,3,4]
2> lesson5:map(fun(X) -> X * X end, [1, 2, 3]).
[1,4,9]

Now the fold function, or, as mentioned earlier, folds. There are two of them: left and right. For lists, these functions fold a list either to the left or to the right by applying a binary function F (sometimes called the aggregation function) to an element of the list and the accumulator, producing a new accumulator. The result is the accumulator you end up with when there are no elements left in the list.

Let's first implement it and then I'll show you some examples so you get a better understanding of how it works.

foldr(_, Acc, []) -> Acc;
foldr(F, Acc, [H|T]) -> F(H, foldr(F, Acc, T)).

That's it. The function accepts a binary function F that takes an element of the list and the current accumulator (Acc) and produces a new accumulator. The elements are combined from right to left: before F can be applied to the head, its second argument, the nested foldr call over the tail, has to be evaluated first, and this repeats recursively until the base case is reached.

1> lesson5:foldr(fun(X, Y) -> X + Y end, 0, [1, 2, 3]).
6
2> lesson5:foldr(fun erlang:'+'/2, 0, [1, 2, 3]).
6

Both calls are equivalent and return 6: in Erlang, operators are ordinary functions exported from the erlang module, so fun erlang:'+'/2 refers to the + operator. Check the erlang module documentation for more operators. Now let's implement the foldl function.

foldl(_, Acc, []) -> Acc;
foldl(F, Acc, [H|T]) -> foldl(F, F(H, Acc), T).

The implementation looks very similar to foldr, but it behaves quite differently. The second clause forces BEAM to compute F(H, Acc) before making the recursive call, which results in a different evaluation order. Let's illustrate it.

1> lesson5:foldr(fun erlang:'/'/2, 1, [2, 3, 4, 5]).
2> lesson5:foldl(fun erlang:'/'/2, 1, [2, 3, 4, 5]).

The first call expands to 2 / (3 / (4 / (5 / 1))), while the second one expands to 5 / (4 / (3 / (2 / 1))). That's because foldr starts combining from the rightmost element of the list, while foldl starts from the leftmost one; with a non-commutative operation like division, the results differ.
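To see the two evaluation orders without floating-point noise, here is a sketch that folds into a string describing the expression instead of a number. The module fold_trace and the helper show/2 are hypothetical; foldr/3 and foldl/3 are copies of the implementations above.

```erlang
-module(fold_trace).
-export([show_r/1, show_l/1]).

foldr(_, Acc, []) -> Acc;
foldr(F, Acc, [H|T]) -> F(H, foldr(F, Acc, T)).

foldl(_, Acc, []) -> Acc;
foldl(F, Acc, [H|T]) -> foldl(F, F(H, Acc), T).

%% Instead of adding, build the text "(X + Acc)" so the nesting is visible.
show(X, Acc) -> "(" ++ integer_to_list(X) ++ " + " ++ Acc ++ ")".

show_r(L) -> foldr(fun show/2, "0", L).
show_l(L) -> foldl(fun show/2, "0", L).
```

fold_trace:show_r([1, 2, 3]) returns "(1 + (2 + (3 + 0)))", while fold_trace:show_l([1, 2, 3]) returns "(3 + (2 + (1 + 0)))", mirroring the two division expansions.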

One important note before we move on. Take a look at foldl: the very last thing it does is call foldl itself. At that moment all of its arguments are already computed and nothing else is left to evaluate in the current call. Such a function is called tail-recursive, and BEAM can apply tail-call optimization to it: the recursive call reuses the current stack frame instead of growing the stack. Tail-recursive functions produce iterative processes and can be expressed as loops, while functions like foldr (we call them body-recursive) produce recursive processes, which are harder to optimize and usually require more memory; in fact, the memory needed by a body-recursive function grows with the size of the list being processed. The running time, though, might be almost the same, so there is often no need to put extra effort into rewriting functions to be tail-recursive.
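As a side-by-side illustration, here are two ways to sum a list: a body-recursive version and a tail-recursive one that carries an accumulator. The module recursion_demo and its function names are hypothetical, chosen for this sketch.

```erlang
-module(recursion_demo).
-export([sum_body/1, sum_tail/1]).

%% Body-recursive: the addition happens after the recursive call returns,
%% so every pending addition needs stack space until the list is exhausted.
sum_body([]) -> 0;
sum_body([H|T]) -> H + sum_body(T).

%% Tail-recursive: the recursive call is the last thing the clause does,
%% and the running total travels in the accumulator argument.
sum_tail(L) -> sum_tail(L, 0).

sum_tail([], Acc) -> Acc;
sum_tail([H|T], Acc) -> sum_tail(T, H + Acc).
```

Both return 5050 for lists:seq(1, 100); the difference is only in how the work is organized, which mirrors foldr versus foldl.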

Also keep in mind that not every body-recursive algorithm can easily be converted to tail recursion. If your function makes a recursive call and then still has to do work with its result, rewriting it may be hard or may require restructuring the algorithm. You will notice this especially with multiple recursion, mentioned at the beginning of this article.
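For Fibonacci specifically a tail-recursive version does exist, but note that it is not a mechanical rewrite of the multiple recursion: the algorithm itself changes to carry the two most recent numbers along in accumulators. A sketch, using the hypothetical module name fib_iter:

```erlang
-module(fib_iter).
-export([fib/1]).

%% Tail-recursive Fibonacci: instead of two recursive calls per step,
%% carry the two most recent Fibonacci numbers (A, B) as accumulators.
fib(N) when is_integer(N), N >= 0 -> fib(N, 0, 1).

fib(0, A, _B) -> A;
fib(N, A, B) -> fib(N - 1, B, A + B).
```

This version makes N + 1 calls for position N instead of the exponentially growing number made by the naive Fib fun.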

I also recommend keeping in mind that you first write code that works, and only then optimize it, not vice versa. I encourage you to prefer readability over performance, at least for the first version of the programs you write. That said, when choosing between lists:foldr and lists:foldl, prefer foldl when possible. I plan to write a separate article on performance in Erlang, where we will talk about body-recursive and tail-recursive functions once again.

Now let's try another example.

1> lesson5:foldr(fun(X, Y) -> [X|Y] end, [], [1, 2, 3]).
[1,2,3]
2> lesson5:foldl(fun(X, Y) -> [X|Y] end, [], [1, 2, 3]).
[3,2,1]

If you construct a list by prepending each element to an initially empty accumulator, you get either the original list or the original list reversed, depending on which fold you use: foldl reverses the order of elements, while foldr preserves it. This can be used to implement the reverse function.

reverse(L) -> foldl(fun(H, T) -> [H|T] end, [], L).

In fact, having just foldl and foldr surprisingly allows you to implement almost any list algorithm that requires traversal. Not always in the most efficient way, but it's possible. For instance, you can implement any and all (as in lists:any/2 and lists:all/2) with a fold, but it wouldn't be very efficient: lists:any/2 and lists:all/2 stop as soon as the result is known, while a fold-based version always traverses the entire list. Apart from that, folds are really useful.
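Here is a sketch of fold-based any/2 and all/2 to make that trade-off visible. The module fold_any_all is hypothetical, and foldl/3 is a local copy of the implementation above.

```erlang
-module(fold_any_all).
-export([any/2, all/2]).

foldl(_, Acc, []) -> Acc;
foldl(F, Acc, [H|T]) -> foldl(F, F(H, Acc), T).

%% orelse/andalso stop calling P once the answer is known, but foldl
%% still walks over every remaining element of the list.
any(P, L) -> foldl(fun(X, Acc) -> Acc orelse P(X) end, false, L).

all(P, L) -> foldl(fun(X, Acc) -> Acc andalso P(X) end, true, L).
```

The starting values follow the same logic as the library functions: any over an empty list is false and all over an empty list is true.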

An important note: here I'm talking about Erlang and other languages with eager evaluation. In languages with lazy evaluation, such as Haskell, the folding function can receive the rest of the fold as an unevaluated expression and short-circuit it, so foldr behaves fundamentally differently from foldl there: it can stop without traversing the entire list.
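Erlang itself is eager, but the effect can be imitated with Funs: instead of the already-computed rest of the fold, the folding function receives a zero-argument fun (a thunk) and calls it only when it needs it. This is a sketch with hypothetical names (lazy_foldr, foldr_lazy/3), not a library function.

```erlang
-module(lazy_foldr).
-export([foldr_lazy/3, any/2]).

%% Like foldr, but F gets the rest of the fold wrapped in a thunk,
%% so F decides whether the remaining elements are processed at all.
foldr_lazy(_, Acc, []) -> Acc;
foldr_lazy(F, Acc, [H|T]) -> F(H, fun() -> foldr_lazy(F, Acc, T) end).

%% Stops at the first element satisfying P: Rest() is never called then.
any(P, L) -> foldr_lazy(fun(X, Rest) -> P(X) orelse Rest() end, false, L).
```

For example, any(fun(X) -> 10 div X > 2 end, [1, 0]) returns true without ever evaluating 10 div 0.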

Let's now rewrite the map function using folds.

map2(F, L) -> foldr(fun(H, T) -> [F(H)|T] end, [], L).

map3(F, L) -> reverse(foldl(fun(H, T) -> [F(H)|T] end, [], L)).

The map3 example shows a common way to handle such cases: you process the list with foldl for performance reasons and then reverse the result to restore the original order. I'll be using the foldr function, though, for simplicity.

The filter function is implemented similarly, but this time an element is prepended to the accumulator only if it satisfies the given predicate; otherwise the accumulator is returned unmodified.

filter(P, L) -> foldr(
    fun(H, T) ->
        case P(H) of
            true -> [H|T];
            _ -> T
        end
    end, [], L).

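Please notice that we can't just write fun(H, T) when P(H): that is illegal, because guards may only contain a limited set of guard expressions (type tests, comparisons, arithmetic and certain BIFs), not calls to arbitrary functions or funs. That's why we use a case expression here. Also note that the inner fun uses P from the enclosing function; this gently touches on scoping, which will be discussed in the next article.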

The length function is foldl with an aggregation function that ignores its first argument (the element) and applies the successor function to its second argument (the accumulator).

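%% length/1 shares its name with the auto-imported BIF erlang:length/1.
%% If you call it unqualified inside the module, the compiler reports an
%% ambiguous call; add -compile({no_auto_import, [length/1]}). to resolve it.
length(L) -> foldl(fun(_, Acc) -> Acc + 1 end, 0, L).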

The sum and product functions are a combination of foldl and + and * operators, respectively.

sum(L) -> foldl(fun erlang:'+'/2, 0, L).

product(L) -> foldl(fun erlang:'*'/2, 1, L).

Please notice that in the length function we use 0 as the starting point: it's the length of an empty list and the identity for addition. The same applies to sum, while product uses 1, the identity for multiplication.

The max function is simple as well: you combine foldl with a function that returns whichever of its two arguments is greater. We can use guards to implement that.

max([H|T]) -> foldl(fun(X, Y) when X > Y -> X; (_, Y) -> Y end, H, T).

The min function is implemented in a similar way and I'm omitting it here.

If you take a look at the max function, you'll notice it's undefined for empty lists: calling it with [] fails with a function_clause error. For non-empty lists, we split the list into its head and tail and use the head as the initial accumulator for foldl.

We conclude this article by solving a little puzzle: find the sum of the prime numbers between 1 and 1000.

factors(N) -> filter(fun(X) -> N rem X =:= 0 end, lists:seq(1, N)).

is_prime(N) -> factors(N) =:= [1, N].

primes(N) -> filter(fun is_prime/1, lists:seq(1, N)).

sum_of_primes(N) -> sum(primes(N)).

This is a naive and inefficient way to test primality, but it works and gives you a taste of writing programs with list functions.

The factors function returns the list of factors of a given N. The is_prime predicate is based on the idea that a number is prime if its only factors are 1 and the number itself (this also correctly rules out 1, whose only factor list is [1]). The primes function returns the list of prime numbers from 1 to N, while sum_of_primes computes their sum; for N = 1000 the answer is 76127. And that's it!
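If you want something less naive, one common improvement is trial division only up to the square root of N. This is a sketch with the hypothetical module name primes_demo; it swaps the factor-list test for a different technique and uses lists:sum/1 and lists:filter/2 from the standard library.

```erlang
-module(primes_demo).
-export([is_prime/1, sum_of_primes/1]).

%% Trial division up to sqrt(N): if N has a divisor other than 1 and
%% itself, at least one such divisor D satisfies D * D =< N.
is_prime(N) when N < 2 -> false;
is_prime(N) -> not has_divisor(N, 2).

has_divisor(N, D) when D * D > N -> false;
has_divisor(N, D) when N rem D =:= 0 -> true;
has_divisor(N, D) -> has_divisor(N, D + 1).

sum_of_primes(N) -> lists:sum(lists:filter(fun is_prime/1, lists:seq(1, N))).
```

It gives the same answers as the fold-based version, but each test stops after at most about sqrt(N) divisions instead of building the full list of factors.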

Funs give us a good instrument for creating new abstractions. Combined with recursion and list functions like filter, fold and map, Funs make our programs easier to express and to understand. But that's not all: Funs are building blocks for even smarter abstractions, and that's what we're going to explore in the next article.

P.S.: if you liked this article, please share it with your friends and colleagues. And as always, don't forget to inspect and download the sources.