All functions in OCaml are functions of one argument (Currying)

Consider a simple example of a function that needs two arguments to compute its answer:

let sum x y = x + y

This definition is actually equivalent to:

let sum = (fun x -> (fun y -> x + y ))

When you use the function, sum 10 20, the function is applied as ((sum 10) 20). First, (sum 10) gets evaluated to a function (fun y -> 10 + y), which then gets applied to 20:

    sum 10 20

--> (fun x -> (fun y -> x + y )) 10 20

--> (fun y -> 10 + y) 20

--> (10 + 20)

--> 30

This technique of representing a function of multiple arguments as a sequence of functions of one argument is called currying (named after Haskell Curry).
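
To make the contrast concrete, here is a small sketch comparing the curried sum with a version that takes a single pair as its argument. The names sum_pair, curry and uncurry are illustrative only and do not come from the notes above.

```ocaml
(* Curried: takes x and returns a function that takes y. *)
let sum x y = x + y

(* Uncurried: takes one argument, which happens to be a pair. *)
let sum_pair (x, y) = x + y

(* Hypothetical helpers that convert between the two styles. *)
let curry f x y = f (x, y)
let uncurry f (x, y) = f x y

let a = sum 10 20              (* 30 *)
let b = sum_pair (10, 20)      (* 30 *)
let c = curry sum_pair 10 20   (* 30 *)
let d = uncurry sum (10, 20)   (* 30 *)
```

Only the curried form can be applied to its first argument alone; sum_pair always needs the whole pair.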

Currying enables partial function application

let sum x y = x + y
let add_ten = sum 10

By applying sum to 10, we get a new function add_ten. It takes a single integer argument and returns its value plus 10:

# let sum x y = x + y ;;
val sum : int -> int -> int = <fun>

# let add_ten = sum 10 ;;
val add_ten : int -> int = <fun>

# add_ten 5 ;;
- : int = 15

A more interesting example:

# List.map ;;
- : ('a -> 'b) -> 'a list -> 'b list = <fun>

# List.map (fun x -> x * x) ;;
- : int list -> int list = <fun>

let square_list = List.map (fun x -> x * x)

square_list is the function List.map with the first argument already applied. It just remains to apply it to a list of integers:

square_list [10; 20; 30]  -->  [100; 400; 900]

square_list [5; 4; 6; 3; 7]  -->  [25; 16; 36; 9; 49]
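
The same trick works with any curried function. As a hedged sketch, the names keep_positive and join_with_commas below are made up for illustration; List.filter and String.concat are standard library functions.

```ocaml
(* List.filter : ('a -> bool) -> 'a list -> 'a list
   Supplying only the predicate leaves a function on lists. *)
let keep_positive = List.filter (fun x -> x > 0)

(* String.concat : string -> string list -> string
   Supplying only the separator leaves a function on string lists. *)
let join_with_commas = String.concat ", "

let positives = keep_positive [3; -1; 0; 7]       (* [3; 7] *)
let joined = join_with_commas ["a"; "b"; "c"]     (* "a, b, c" *)
```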

Pipeline |> operator

If we want to count the number of positive elements in a list, we can write:

let count_positive ls =
  let pos = List.filter (fun x -> x > 0) ls in
  List.length pos

or

let count_positive ls =
  List.length (List.filter (fun x -> x > 0) ls)

This is a frequent pattern: we often have to apply several functions in a row, like:

h (g (f x))

Unfortunately, multiple nested parentheses often make code hard to read. There is a convenient operator |> called "pipeline" (or "pipe") that can be used in this situation:

x |> f |> g |> h

It often improves code readability:

let count_positive ls =
  ls
  |> List.filter (fun x -> x > 0)
  |> List.length
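
The operator is not magic: it behaves like a one-line function that swaps the order of the function and its argument. The sketch below defines a look-alike under the hypothetical name |>> (chosen only to avoid shadowing the built-in |>) and uses it in count_positive.

```ocaml
(* A sketch of how a pipeline operator could be defined.
   The name |>> is hypothetical; the built-in operator is |>. *)
let ( |>> ) x f = f x

let count_positive ls =
  ls
  |>> List.filter (fun x -> x > 0)
  |>> List.length

let n = count_positive [4; -2; 0; 9; 1]   (* 3 *)
```

Function application binds tighter than the operator, so List.filter (fun x -> x > 0) is evaluated to a one-argument function first, and only then is the list passed to it.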

Exercises:

Consider the following definitions:

let double x = 2 * x
let twice f x = f (f x)
let quad = twice double

Use the toplevel to determine what the type of quad is. Explain how it can be that quad is not syntactically written as a function that takes an argument, and yet its type shows that it is in fact a function.

Define a function repeat f n x, which applies function f to x exactly n times:

repeat f 0 x   -->   x
repeat f 1 x   -->   f x
repeat f 2 x   -->   f (f x)
...

(Can you use the "pipeline" operator to define it?)

Solution:

let rec repeat f n x =
  if n <= 0 then x else repeat f (n-1) (f x)

or using "pipeline":

let rec repeat f n x =
  if n <= 0 then x else x |> f |> repeat f (n-1)

Given a function range (a, b) that creates a list of all integers x in the range a ≤ x < b:

let rec range (a, b) =
  if a >= b then
    []
  else
    a :: range (a+1, b)

Define a function fives n that produces the list of n smallest positive integers divisible by five:

fives 7  -->  [5; 10; 15; 20; 25; 30; 35]

Solution:

let fives n =
  (1, n+1) |> range |> List.map (fun x -> x * 5)

or

let fives n =
  range (1, n+1) |> List.map (fun x -> x * 5)

or

let fives n =
  List.map (fun x -> x * 5) (range (1, n+1))

Using the functions range and repeat from the exercises above, define a function powers_of x n that produces the list of the first n powers of x, starting from x to the power 0:

powers_of 2 5  -->  [1; 2; 4; 8; 16]
powers_of 10 3  -->  [1; 10; 100]

Solution:

let power x n = repeat (fun a -> a * x) n 1

let powers_of x n =
  (0, n) |> range |> List.map (power x)

or

let powers_of x n =
  let mult a = a * x in
  (0, n) |> range |> List.map (fun i -> repeat mult i 1)

Operators as functions

In OCaml, when a binary operator is put in parentheses, such as (+), it acts as a function of two arguments:

(+) 2 3     -->  5
(+) 10 200  -->  210
(-) 100 93  -->  7
(^) "cat" "dog" --> "catdog"
(@) [1;2;3] [4;5;6] --> [1;2;3;4;5;6]

The multiplication operator requires additional spaces to avoid confusion with comments (* and *):

( * ) 5 7   -->  35
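
Because a parenthesized operator is an ordinary curried function, it can be partially applied or passed to a higher-order function. A short sketch (the names add_ten, sub_from_ten, greet and doubled are illustrative):

```ocaml
(* Fixing the first argument of an operator gives a one-argument function. *)
let add_ten = (+) 10          (* same as fun y -> 10 + y *)
let sub_from_ten = (-) 10     (* same as fun y -> 10 - y, not y - 10 *)
let greet = (^) "Hello, "     (* prepends a greeting to its argument *)

(* Operators can be passed directly to higher-order functions. *)
let doubled = List.map (( * ) 2) [1; 2; 3]   (* [2; 4; 6] *)
```

Note that partial application always fixes the left operand, which matters for operators such as - and ^ that are not commutative.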

Fold (generalizing computation with accumulators)

These notes on fold operations are not exhaustive; please see the textbook (sections 4.5 - 4.9) for a more detailed discussion.

Consider two functions that compute the sum of a list and find its maximum. Both functions are given an initial sum, or an initial guess for the maximum, as the argument a:

let rec sum_list a ls =
  match ls with
  | hd :: tl -> sum_list (a + hd) tl
  | [] -> a

let rec max_list a ls =
  match ls with
  | hd :: tl -> max_list (max a hd) tl
  | [] -> a

sum_list 0 [1; 2; 50; 10]  -->  63
max_list 3 [6; 11; 7; 10]  -->  11

Are sum_list and max_list tail-recursive?

Answer:

Yes, both functions are tail-recursive.

These functions work similarly: at each step, they combine the accumulator a with the head of the list hd (using + or max operation).


This computation can be generalized if we allow an arbitrary combining operation f to act on a and hd.

Define a higher-order function fold_left

let rec fold_left f a ls =
  match ls with
  | hd :: tl -> fold_left f (f a hd) tl
  | [] -> a

val fold_left : ('a -> 'b -> 'a) -> 'a -> 'b list -> 'a

Then,

let sum_list a ls = fold_left (fun x y -> x + y) a ls
let max_list a ls = fold_left max a ls

Since (fun x y -> x + y) is exactly the summation operator +, we can write:

let sum_list a ls = fold_left (+) a ls
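
Only the combining function and the initial accumulator change from one use to the next. As a sketch, here are a few more functions built on the same fold_left (repeated here so the block is self-contained); the names product, rev and all_positive are illustrative.

```ocaml
let rec fold_left f a ls =
  match ls with
  | hd :: tl -> fold_left f (f a hd) tl
  | [] -> a

(* Accumulator: running product, starting from 1. *)
let product ls = fold_left ( * ) 1 ls

(* Accumulator: the reversed prefix seen so far. *)
let rev ls = fold_left (fun acc e -> e :: acc) [] ls

(* Accumulator: whether every element so far was positive. *)
let all_positive ls = fold_left (fun acc e -> acc && e > 0) true ls
```

For instance, product [2; 3; 4] evaluates as ((1 * 2) * 3) * 4.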

…​ and its non-tail-recursive "sibling" fold_right

let rec fold_right f ls a =
  match ls with
  | hd :: tl -> f hd (fold_right f tl a)
  | [] -> a

val fold_right : ('a -> 'b -> 'b) -> 'a list -> 'b -> 'b

The functions sum_list and max_list can be redefined with fold_right:

let sum_list2 a ls = fold_right (+) ls a
let max_list2 a ls = fold_right max ls a

Comparing fold_left and fold_right

fold_right f [e1; e2; e3; e4] a = f e1 (f e2 (f e3 (f e4 a)))
fold_left f a [e1; e2; e3; e4] = f (f (f (f a e1) e2) e3) e4

If the operation f is commutative and associative, such as max or the integer operators + and *, then fold_left and fold_right compute the same result; the computation is just carried out in a different order.

Otherwise, the results can differ. For example, with the string concatenation operator ^:

# fold_left (^) "a" ["1"; "2"; "3"; "4"] ;;
- : string = "a1234"

# fold_right (^) ["1"; "2"; "3"; "4"] "a" ;;
- : string = "1234a"

The List module provides the functions List.fold_left and List.fold_right.
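
Subtraction is another operation that is neither commutative nor associative, so the two library folds disagree on it. A small sketch (the names left and right are illustrative):

```ocaml
(* ((0 - 1) - 2) - 3 *)
let left = List.fold_left (-) 0 [1; 2; 3]     (* -6 *)

(* 1 - (2 - (3 - 0)) *)
let right = List.fold_right (-) [1; 2; 3] 0   (* 2 *)
```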

Can we implement function List.map using fold functions?

Solution:

let map1 f ls =
  List.fold_right (fun elem a -> f elem :: a) ls []

fold_left would produce a reversed mapped list

let map2 f ls =
  List.fold_left (fun a elem -> f elem :: a) [] ls

# map1 (fun x -> x*x) [10; 20; 30; 40] ;;
- : int list = [100; 400; 900; 1600]

# map2 (fun x -> x*x) [10; 20; 30; 40] ;;
- : int list = [1600; 900; 400; 100]

So, if one used map2, they would have to call List.rev to reverse the produced list.

Generally, in OCaml, since fold_left is tail-recursive, it is more practical to use fold_left, rather than fold_right.
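
As a hedged illustration, the sketch below combines fold_left with List.rev to obtain a map that runs in constant stack space, and applies it to a list of one million elements. The name map_tr is made up here. Whether a fold_right-based map would actually exhaust the stack on such a list depends on the OCaml version and the system stack limit, so the sketch only demonstrates the tail-recursive variant.

```ocaml
(* A map built from a tail-recursive fold followed by a reversal. *)
let map_tr f ls =
  ls
  |> List.fold_left (fun acc e -> f e :: acc) []
  |> List.rev

(* A long list: [0; 1; 2; ...; 999_999]. *)
let big = List.init 1_000_000 (fun i -> i)

let squares = map_tr (fun x -> x * x) big
let squares_len = List.length squares   (* 1_000_000 *)
```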

Both fold functions can be employed for computation that needs to carry certain state, while iterating over the collection (list). In the sum_list example, it carried the partial sum of the elements. In the max_list, it carried the value of the largest element so far.

For example, let’s say that we want to compute the average of all positive elements in a list.

In C, we could write:

int sum = 0;  // initial values
int num = 0;

for (int i = 0; i < size; i++) {
  int elem = arr[i];

  if (elem > 0) {
    sum = sum + elem;  // update sum and num
    num = num + 1;
  }
}

double avg_positive = (double) sum / (double) num;

Note that the variables sum and num change during the iteration, carrying the updated state (sum, num). In the beginning the state is (0, 0), and at the end of the loop, sum and num hold the sum of all positive elements and their count.

An equivalent OCaml version using fold_left would look like:

let sum, num =
  List.fold_left
    (fun (s, n) elem ->
        if elem > 0 then
          (s + elem, n + 1)
        else
          (s, n)
    )
    (0, 0)
    ls
in

float sum /. float num

Using accumulators like these sum and num, we can achieve the same result as the imperative program without using any mutable variables.

When anonymous functions are large multi-line expressions, it is good practice to define them before they are used. So the previous example becomes:

let sum, num =
  let update (s, n) elem =
    if elem > 0 then (s + elem, n + 1) else (s, n)
  in
  List.fold_left update (0, 0) ls
in

float sum /. float num

Any imperative loop that iterates over a collection (list, array, etc.) can be rewritten as a fold function application.

Of course, the average of all positive elements can be also computed as follows:

let pos_ls = List.filter (fun x -> x > 0) ls in
let sum = List.fold_left (+) 0 pos_ls in
let num = List.length pos_ls in

float sum /. float num

This code is slightly less efficient, since it traverses the list several times, but it is more readable.
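
The fragments above assume a list ls is already in scope and that it contains at least one positive element. A possible way to package the fold-based version as a function is sketched below; the name avg_positive and the choice to return an option for the no-positive-elements case are assumptions made for this example.

```ocaml
(* Average of the positive elements, or None if there are none. *)
let avg_positive ls =
  let update (s, n) elem =
    if elem > 0 then (s + elem, n + 1) else (s, n)
  in
  let sum, num = List.fold_left update (0, 0) ls in
  if num = 0 then None else Some (float sum /. float num)

let r1 = avg_positive [4; -2; 6; 0]   (* Some 5. *)
let r2 = avg_positive [-1; 0]         (* None *)
```

Returning an option avoids dividing by zero when the list has no positive elements.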

Exercises

Implement the list function length using fold_left or fold_right

length [2; 30; 5; 7; 9; 1]  -->  6
length []  -->  0

Solution:

let length ls =
  List.fold_left (fun acc e -> acc + 1) 0 ls

or

let length ls =
  List.fold_right (fun e acc -> acc + 1) ls 0

Implement the list function filter using fold_left or fold_right

filter (fun x -> x > 5) [2; 30; 5; 7; 9; 1]  -->  [30; 7; 9]

Solution:

let filter f ls =
  List.fold_right
    (fun e acc -> if f e then e::acc else acc)
    ls []

or

let filter f ls =
  let ls2 =
    List.fold_left
      (fun acc e -> if f e then e::acc else acc)
      [] ls
  in
  List.rev ls2

or

let filter f ls =
  ls
  |> List.fold_left
      (fun acc e -> if f e then e::acc else acc)
      []
  |> List.rev

Print elements at even indexes in the list using fold_left or fold_right

print_even_index [1; 3; 5; 7; 9; 11]
               (* 0  1  2  3  4  5 *)

Should print:

1 5 9

Solution:

let print_even_index ls =
  List.fold_left (fun i e ->
      if i mod 2 = 0 then Printf.printf "%i " e;
      i + 1
    )
    0 ls

The accumulator i is the counter: 0 1 2 3 …​ , incrementing by one on each iteration.

Implement the function slice that extracts a fragment of a list using fold_left or fold_right.

slice a b ls should return a sublist of ls containing only the elements at indexes i with a ≤ i < b. (It should act like Python’s ls[a:b] notation for list slicing.)

slice 0 3 [1; 3; 5; 7; 9; 11]  -->  [1; 3; 5]
slice 2 5 [1; 3; 5; 7; 9; 11]  -->  [5; 7; 9]
slice 5 6 [1; 3; 5; 7; 9; 11]  -->  [11]
slice 3 3 [1; 3; 5; 7; 9; 11]  -->  []

Hint: Index i can be part of the accumulator.

Solution:

let slice a b ls =
  let _, subls =
    List.fold_left (fun (i, acc) e ->
        if a <= i && i < b then
          (i + 1, e :: acc)
        else
          (i + 1, acc)
      )
      (0, []) ls
  in
  List.rev subls

or

let slice a b ls =
  let _, subls =
    List.fold_right (fun e (i, acc) ->
        if a <= i && i < b then
          (i - 1, e :: acc)
        else
          (i - 1, acc)
      )
      ls (List.length ls - 1, [])
  in
  subls

In the fold_left case, the counter i starts at 0 and increments on each step, while in the fold_right case, the counter starts at List.length ls - 1 and decrements on each step.