Monday, April 29, 2013

Scala Lists - Partition

Two very common list functions are filter and filterNot. The former returns a new list containing only the elements for which a specified predicate holds true; the latter returns a new list containing only the elements for which the predicate does not hold. For example, we can use these two functions to traverse a list of integers and generate a list of all positive integers, as well as a list of all non-positive integers (in this example, the negative ones).

    val l: List[Int] = List(-12, 2, -2, -3, 4, 5, 6)
    
    val positiveList = l.filter(x => x > 0)
    val negativeList = l.filterNot(x => x > 0)


The list function partition combines both of these operations into a single traversal of the original list. It returns a pair (a Tuple2) of lists: the first holds the values for which the predicate is true, and the second holds the values for which it is false.

    val l: List[Int] = List(-12, 2, -2, -3, 4, 5, 6)
    
    val positiveAndNegativeLists = l.partition(x => x > 0)


We can then access each list with positiveAndNegativeLists._1 and positiveAndNegativeLists._2: the former holds all elements for which the predicate is true, and the latter holds all elements for which it is false.

    for (y <- positiveAndNegativeLists._1) println(y)
    for (y <- positiveAndNegativeLists._2) println(y)
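Because partition returns a Tuple2, a pattern in the val definition can unpack both lists at once, which avoids the ._1 and ._2 accessors. A minimal sketch; the names positives and nonPositives are my own. Note that the second list holds every element for which the predicate is false, so a zero would land there too.

```scala
val l: List[Int] = List(-12, 2, -2, -3, 4, 5, 6)

// Destructure the pair returned by partition into two named lists
val (positives, nonPositives) = l.partition(x => x > 0)

println(positives)    // List(2, 4, 5, 6)
println(nonPositives) // List(-12, -2, -3)

// partition keeps the original order and agrees with filter/filterNot
assert(positives == l.filter(x => x > 0))
assert(nonPositives == l.filterNot(x => x > 0))
```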

Scala Lists - Map and Filter

Two very common List operations in Scala are map and filter. Let's start with map, which transforms every element of a list into some other value and returns a new list of the results. For example, if we have a list of integer values and need to produce a new list in which each value of the original list is multiplied by 3, we can write the following map implementation in Scala.

    val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)
    
    val newList = l.map(x => x * 3)
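To make the behavior concrete, here is a small sketch that prints the result of the same map call and then shows that the function passed to map may return a different type than the list's element type. The "item-" labels are an illustrative choice of mine.

```scala
val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)

// Each element is multiplied by 3; the new list keeps the same length and order
val tripled = l.map(x => x * 3)
println(tripled) // List(3, 6, 6, 9, 12, 15, 18)

// The mapping function may change the element type, here from Int to String
val labels: List[String] = l.map(x => "item-" + x)
println(labels.head) // item-1
```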


Another useful function is filter. Like map, filter traverses every element in the list; however, it returns a new list containing only the elements of the original list for which a particular predicate (test) holds true. For example, if we wanted to filter out all odd integer values in a list, we can write the following in Scala.

    val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)
    
    val newList = l.filter(x => x % 2 == 0)
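A short sketch of what this filter produces, plus the same predicate written with Scala's placeholder syntax (_), which is shorthand for a one-argument function literal. The names evens and evensShort are my own.

```scala
val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)

// Keep only the elements for which the predicate returns true
val evens = l.filter(x => x % 2 == 0)
println(evens) // List(2, 2, 4, 6)

// The same predicate using placeholder syntax
val evensShort = l.filter(_ % 2 == 0)
assert(evens == evensShort)
```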


Of course, we can also chain map and filter. Here's an example that triples each value and then keeps only the even results.

    val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)
    
    val newList = l.map(x => x * 3).filter(x => (x % 2) == 0)
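The chained version builds an intermediate list from map before filter runs. One alternative, offered here as an option rather than a replacement, is collect, which applies a partial function and keeps only the elements it is defined for, so mapping and filtering happen in a single traversal. A minimal sketch:

```scala
val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)

// map then filter: two traversals and an intermediate list
val chained = l.map(x => x * 3).filter(x => (x % 2) == 0)

// collect: one traversal; the guard plays the role of filter
// and the right-hand side plays the role of map
val collected = l.collect { case x if (x * 3) % 2 == 0 => x * 3 }

println(chained) // List(6, 6, 12, 18)
assert(chained == collected)
```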


Scala Parser Combinators - Part 3

In a previous post on Scala parser combinators, I demonstrated how to use Scala's parser library to write a simple arithmetic parser. That parser evaluated an arithmetic expression and returned the result of the operation. In this tutorial, I'm going to extend it so that it generates a parse tree instead of evaluating the expression. For the expression "2 * 8 + 6", the parser will output the following parse tree. (Because this grammar gives every operator the same precedence and groups to the right, the tree represents 2 * (8 + 6) rather than the conventional (2 * 8) + 6.)

      *
    /   \
   2     +
       /   \
      8     6


Notice that every node in our tree is either an operator node, consisting of an operator with a left and a right side, or a number. For example, "8+6" is the addition operator (+) with a left side of 8 and a right side of 6; 8 and 6 are the other type of tree node (i.e. numbers). We'll start by defining these two types of nodes as case classes that extend a common Expr base class.

  sealed abstract class Expr
  case class Number(value: Int) extends Expr
  case class Operator(symbol: String, left: Expr, right: Expr) extends Expr
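Because Number and Operator are case classes, we can build the tree for "2 * 8 + 6" by hand and see how it will print before writing any parser rules. A small sketch; the classes are repeated so the snippet stands alone, and the val name tree is my own.

```scala
sealed abstract class Expr
case class Number(value: Int) extends Expr
case class Operator(symbol: String, left: Expr, right: Expr) extends Expr

// The tree from the diagram: "*" at the root, "+" as its right child
val tree: Expr = Operator("*", Number(2), Operator("+", Number(8), Number(6)))

// Case classes generate a readable toString automatically
println(tree) // Operator(*,Number(2),Operator(+,Number(8),Number(6)))
```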


Next, we'll define the rules of our parser. If you read my previous posts on Scala parser combinators, you'll notice that the only additions are the "term" and "factor" members; "expr" is the evaluating rule from those posts.

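  // RegexParsers is in scala-parser-combinators (part of the standard
  // library in Scala 2.10, a separate module from Scala 2.11 on)
  import scala.util.parsing.combinator.RegexParsers

  class ExprParser extends RegexParsers {
    val number = "[1-9][0-9]*".r

    // Evaluating rule from the previous posts: computes an Int directly
    def expr: Parser[Int] = (number ^^ { _.toInt }) ~ opt(operator ~ expr) ^^ {
      case a ~ None => a
      case a ~ Some("*" ~ b) => a * b
      case a ~ Some("/" ~ b) => a / b
      case a ~ Some("+" ~ b) => a + b
      case a ~ Some("-" ~ b) => a - b
    }

    def operator: Parser[String] = "+" | "-" | "*" | "/"

    // New: builds an Operator node instead of computing a value
    def term: Parser[Expr] = (factor ~ opt(operator ~ term)) ^^ {
      case a ~ None => a
      case a ~ Some(b ~ c) => Operator(b, a, c)
    }

    // New: a number becomes a leaf node
    def factor: Parser[Expr] = number ^^ (n => Number(n.toInt))
  }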

Let's test the parser and inspect the result.

  def main(args: Array[String]): Unit = {
    val parser = new ExprParser
    // parseAll fails unless the whole input matches the "term" rule
    val result = parser.parseAll(parser.term, "2*8+6")

    println(result.get)
  }

The output from this test is the following parse tree.

Operator(*,Number(2),Operator(+,Number(8),Number(6)))


To read this output as a parse tree, notice that a tree of the form "8+6" is printed as "Operator(+,Number(8),Number(6))". The operator is the root of that subtree, so it's listed first, followed by its left and right children. Thus, we have a subtree that looks like this.


         +
       /   \
      8     6



Applying the same reading to the rest of the output, "Operator(*,Number(2),...)", gives the following complete parse tree.

      *
    /   \
   2     +
       /   \
      8     6


Sunday, April 28, 2013

Scala Nested Functions

In Scala, we can define a function within another function. Here's an example of a function named "appliesToMoreThanHalf" that determines whether a given predicate holds true for more than half of the values in a list. It contains a nested function named "iter" that processes the list; because iter is defined inside appliesToMoreThanHalf, it is visible only there and can use s, p, and count from the enclosing function.

  def appliesToMoreThanHalf(s: List[Int], p: Int => Boolean): Boolean = {
    var count = 0

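    // Nested helper: reads s and p and updates count from the enclosing scope
    def iter(startingIndex: Int): Boolean = {
      if (startingIndex < s.length) {
        if (p(s(startingIndex))) {
          count += 1
        }
        iter(startingIndex + 1)
      } else {
        // Every element has been visited; "more than half" means at least n / 2 + 1
        count >= (s.length / 2 + 1)
      }
    }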

    iter(0)
  }

Here's an example of this function being called to determine if more than half of the elements in a list are even (i.e. evenly divisible by 2).

  def main(args: Array[String]): Unit = {
    val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)

    if (appliesToMoreThanHalf(l, p => (p % 2 == 0))) {
      println("the function applies to more than half of the elements")
    } else {
      println("the function does not apply to more than half of the elements")
    }
  }


Granted, there is a much simpler way to determine if more than half of the elements hold true for a particular predicate, but this serves as a demonstration of nested functions.
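One such simpler version, sketched here as an alternative rather than the post's method, uses List's count method, which returns the number of elements satisfying a predicate. It also avoids the repeated s(i) lookups above, each of which walks the list from its head.

```scala
// Count the matching elements in one pass and compare with half the length.
// count > s.length / 2 is the same test as count >= s.length / 2 + 1
def appliesToMoreThanHalf(s: List[Int], p: Int => Boolean): Boolean =
  s.count(p) > s.length / 2

val l: List[Int] = List(1, 2, 2, 3, 4, 5, 6)
println(appliesToMoreThanHalf(l, x => x % 2 == 0)) // true: 4 of 7 are even
```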

Scala Tail Recursion

When you write a recursive function, every call to it pushes another frame onto the call stack. If the recursion goes too deep, you'll get a stack overflow error. For example, look at the following Scala code, which uses recursion to calculate the sum of all integers in a sequence.

  def sum(s: Seq[Int]): BigInt = {
    if (s.isEmpty) 0 else s.head + sum(s.tail)
  }


When we call this method with a very large sequence, we get a stack overflow error.

  Exception in thread "main" java.lang.StackOverflowError
    at scala.collection.AbstractTraversable.<init>(Traversable.scala:105)
    at scala.collection.AbstractIterable.<init>(Iterable.scala:54)
    at scala.collection.AbstractSeq.<init>(Seq.scala:40)
    at scala.collection.immutable.Range.<init>(Range.scala:44)
    at scala.collection.immutable.Range$Inclusive.<init>(Range.scala:330)
    at scala.collection.immutable.Range$Inclusive.copy(Range.scala:333)
    at scala.collection.immutable.Range.drop(Range.scala:170)
    at scala.collection.immutable.Range.tail(Range.scala:196)
    at scala.collection.immutable.Range.tail(Range.scala:44)
    ...


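The Scala compiler can turn a recursive call into a loop, avoiding the stack overflow, when that call is in tail position, i.e. when it is the last step of the method. To have the compiler check that this optimization is applied, add the @tailrec annotation (from the scala.annotation package) to the method definition.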

  import scala.annotation.tailrec
  @tailrec def sum(s: Seq[Int]): BigInt = { if (s.isEmpty) 0 else s.head + sum(s.tail) }

However, if we add the annotation and recompile our example, we get the following compiler error.

could not optimize @tailrec annotated method sum: it contains a recursive call not in tail position


This error occurs because the Scala compiler can't apply the optimization to our code as written. Why is that? If we take the "else" path in our sum method, the recursive call to "sum" on the tail of the list happens first, and its result is then added to the head of the list. The addition, not the recursive call, is the last step in the computation. To make the method tail recursive, we need to restructure it so that the recursive call is the last step. Here's the same algorithm after refactoring: a second parameter, p, accumulates the running total, so the addition happens before the recursive call rather than after it.

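  import scala.annotation.tailrec

  // p accumulates the running total; the recursive call is the last step
  @tailrec def sum(s: Seq[Int], p: BigInt): BigInt = {
    if (s.isEmpty) p else sum(s.tail, s.head + p)
  }

  def main(args: Array[String]): Unit = {
    val result = sum(1 to 1000000, 0)
    println(result)
  }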


Now, after running the example, we get a successful result.

500000500000


Note that the Scala compiler applies this optimization whenever it can, with or without the annotation. However, it's good practice to annotate any method where you expect the optimization to be applied (i.e. where the last step in your recursive algorithm is the call to the recursive function). That way, if the compiler can't apply it, you'll get a compile-time error instead of a surprise stack overflow at run time.
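A common variation combines this with the nested functions shown earlier: keep the original one-argument signature and move the accumulator into a nested helper, so callers don't have to pass the initial 0. A sketch; the helper name loop and the parameter name acc are my own.

```scala
import scala.annotation.tailrec

// Public signature without the accumulator; the nested helper carries the
// running total so that its recursive call is in tail position
def sum(s: Seq[Int]): BigInt = {
  @tailrec
  def loop(rest: Seq[Int], acc: BigInt): BigInt =
    if (rest.isEmpty) acc else loop(rest.tail, acc + rest.head)

  loop(s, 0)
}

println(sum(1 to 1000000)) // 500000500000
```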

Friday, April 5, 2013

Higher-Order Functions in Scala

The functions that most of us are familiar with take values of some type(s) as parameters and return a value of some type as a result. These are called first-order functions. Higher-order functions, however, are functions that take other functions as parameters and/or return a function as a result. In other words, higher-order functions act on other functions.

To demonstrate with a simple example, let's look at how we might define a first-order function that takes two integer values and returns the sum of their squares.


def sumOfSquares(a: Int, b: Int): Int = {
   a * a + b * b
}


Next, let's look at how we could refactor this to use a higher-order function. Here is a function that takes three parameters:
  • a function that takes an Int and returns an Int
  • an Int named a
  • an Int named b

def sumOfTwoOperations(f: Int => Int, a: Int, b: Int): Int = {
   f(a) + f(b)
}


Next, let's call the sumOfTwoOperations function, passing a "squared" function as the first parameter.


def squared(x: Int): Int = x * x

val result = sumOfTwoOperations(squared, 2, 5)   // result = 29


The beauty of higher-order functions is that we can now define another operation, such as "cubed", and pass it to sumOfTwoOperations without changing that function.

def cubed(x: Int): Int = x * x * x

val result = sumOfTwoOperations(cubed, 2, 5)   // result = 133
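The definition at the top of this post also mentions the other direction: a higher-order function can return a function. A minimal sketch, using a hypothetical sumOfTwo that takes only the operation and returns a function of the two Int values; it also shows an anonymous function passed in place of a named one.

```scala
def squared(x: Int): Int = x * x

// A higher-order function that returns a function: given f, it produces
// a new function of two Ints that applies f to each and adds the results
def sumOfTwo(f: Int => Int): (Int, Int) => Int =
  (a, b) => f(a) + f(b)

val sumOfSquares = sumOfTwo(squared)
println(sumOfSquares(2, 5)) // 29

// An anonymous function can be passed directly instead of a named one
val sumOfCubes = sumOfTwo(x => x * x * x)
println(sumOfCubes(2, 5)) // 133
```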