
Scala Version of Approximation Algorithm for Knapsack Problem for Apache Spark

This is the Scala version of the approximation algorithm for the knapsack problem using Apache Spark.

I ran this on a local setup, so it may need some changes if you are using something like a Databricks environment. You will also likely need to set up your Scala environment.
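If you are starting from scratch, a minimal sbt build is one way to get Spark onto the classpath. This is only a sketch: the Scala and Spark version numbers below are assumptions, so match them to whatever your local install or cluster actually runs.

```scala
// build.sbt: a minimal sketch. The version numbers are assumptions; use ones matching your Spark install.
scalaVersion := "2.12.18"

// spark-sql pulls in spark-core; "provided" because spark-submit supplies Spark at runtime.
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.5.1" % "provided"

// spark-mllib is only needed for the RandomRDDs import in the test code.
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.5.1" % "provided"
```

If you launch with `sbt run` rather than `spark-submit`, the `provided` scope keeps Spark off the run classpath, so drop it in that case.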

All the code for this is on GitHub.

First, let's import all the libraries we need.


import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

We'll define an object named knapsack. The name could be more specific about what it does, but it's good enough for this simple test.

object knapsack {


Again, we'll define the knapsack approximation algorithm. It expects a DataFrame with the profits (the values column) and weights, as well as W, the maximum total weight.

  def knapsackApprox(knapsackDF: DataFrame, W: Double): DataFrame = {


Calculate the ratio of profit over weight for each item, and sort from high to low ratio. Discard any items whose weight is already larger than the maximum knapsack size, W.

    val ratioDF = knapsackDF.withColumn("ratio", knapsackDF("values") / knapsackDF("weights"))
    val newRatioDF = (ratioDF
      .filter(ratioDF("weights") <= W)
      .sort(ratioDF("ratio").desc)
      )
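To see what this first step does without a cluster, here is the same logic on an ordinary Scala collection. It is a sketch for illustration only: the `RatioStep` object, the `Item` case class, and `sortedCandidates` are hypothetical names, not part of the GitHub code.

```scala
// Plain-Scala sketch of the ratio/filter/sort step (no Spark).
// Hypothetical names; items are modelled as (name, weight, value).
object RatioStep {
  final case class Item(name: String, weight: Double, value: Double) {
    // Profit per unit of weight, the same "ratio" column the DataFrame code adds.
    def ratio: Double = value / weight
  }

  def sortedCandidates(items: Seq[Item], w: Double): Seq[Item] =
    items
      .filter(_.weight <= w)       // an item heavier than W can never fit
      .sortBy(item => -item.ratio) // highest value per unit weight first
}
```

For example, with W = 5, an item of weight 9 is dropped, and an item with ratio 3 is placed ahead of one with ratio 2.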

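Now we'll use SQL to compute the running (partial) sums of the weights. The `sum(...) OVER (ORDER BY ...)` clause is itself a window function; Spark's DataFrame `Window` API would be another way to write it. The running sum tells us which items fit in the knapsack; remember these are sorted by profit-to-weight ratio, high to low, so we keep items until the running sum exceeds W.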


    newRatioDF.createOrReplaceTempView("tempTable")
    // Use the DataFrame's own session so this also works outside spark-shell, where no global `spark` exists.
    val partialSumWeightsDF = knapsackDF.sparkSession.sql(
      "SELECT item, weights, `values`, ratio, sum(weights) OVER (ORDER BY ratio desc) as partSumWeights FROM tempTable")
    val partialSumWeightsFilteredDF = (
      partialSumWeightsDF
        .filter(partialSumWeightsDF("partSumWeights") <= W)
      )

And now return this new Dataframe, which will have only the objects that fit.

    partialSumWeightsFilteredDF
  }
}
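The partial-sum step can also be written with Spark's DataFrame `Window` API instead of a temp view and SQL string, which is the conversion listed as a next task at the end of this post. Here is a minimal sketch, assuming a DataFrame that already has the `weights` and `ratio` columns from the first step; `KnapsackWindowSketch` and `fittingItems` are hypothetical names.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum}

// Sketch: running sum of weights via the DataFrame Window API (hypothetical object/method names).
// Assumes ratioDF already has "weights" and "ratio" columns.
object KnapsackWindowSketch {
  def fittingItems(ratioDF: DataFrame, w: Double): DataFrame = {
    // Strict row-by-row running total, ordered by ratio from high to low.
    val byRatio = Window
      .orderBy(col("ratio").desc)
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    ratioDF
      .withColumn("partSumWeights", sum(col("weights")).over(byRatio))
      .filter(col("partSumWeights") <= w) // keep only the prefix that fits in the knapsack
  }
}
```

Using `rowsBetween(Window.unboundedPreceding, Window.currentRow)` gives a strict row-by-row running sum; with only an `ORDER BY`, as in the SQL version, the default frame is a range, so items with exactly equal ratios get the same partial sum. Like the SQL version, a window with no `partitionBy` pulls every row into a single partition (Spark logs a warning about this), so this step is effectively serial.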

So this returns the greedy solution, which is fast and easy to parallelize, but is not optimal. Parallel solutions for the optimal knapsack problem are often not as simple, but this was a good way to test out Spark using Scala.
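To make "not optimal" concrete, here is a small plain-Scala sketch (no Spark) comparing the greedy prefix rule with an exact 0/1 knapsack solved by dynamic programming. The names are hypothetical, and integer weights are an assumption made so the DP table can be indexed by capacity. Note that the DP loop runs item by item over a shared table, which is part of why exact solutions do not map onto Spark as neatly.

```scala
// Plain-Scala sketch: greedy prefix (as in the Spark code) vs exact 0/1 knapsack via DP.
// Hypothetical names; integer weights assumed so capacities can index an array.
object GreedyVsExact {
  final case class Item(name: String, weight: Int, value: Int)

  // Greedy: sort by value/weight, keep items while the running weight stays <= w.
  // takeWhile matches the Spark filter because the running sum only grows.
  def greedyValue(items: Seq[Item], w: Int): Int = {
    val sorted = items.filter(_.weight <= w).sortBy(i => -i.value.toDouble / i.weight)
    val runningWeights = sorted.scanLeft(0)(_ + _.weight).tail
    sorted.zip(runningWeights).takeWhile { case (_, total) => total <= w }.map(_._1.value).sum
  }

  // Exact: best(c) is the best value reachable with capacity c using the items seen so far.
  def exactValue(items: Seq[Item], w: Int): Int = {
    val best = Array.fill(w + 1)(0)
    // Iterate capacities downward so each item is used at most once.
    for (item <- items; c <- w to item.weight by -1)
      best(c) = math.max(best(c), best(c - item.weight) + item.value)
    best(w)
  }
}
```

With items a (weight 6, value 7), b and c (weight 5, value 5 each) and W = 10, the greedy rule takes only a, for a value of 7, while the exact answer takes b and c, for a value of 10.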

And here is the test code, which is pretty self-explanatory. The GitHub code is a work in progress, and I have some cleanup to do.

import org.apache.spark.mllib.random.RandomRDDs._
import scala.collection.mutable.ListBuffer

// Knapsack problem size.
val N = 10

// Random
val r = scala.util.Random

// Setup sample data for knapsack.
val knapsackDataListBuffer = ListBuffer[(String, Double, Double)]()
for (k <- 1 to N) {
  knapsackDataListBuffer += (("item_" + k.toString, r.nextDouble() * 10.0, r.nextDouble() * 10.0))
}
val knapsackDataList = knapsackDataListBuffer.toList

// Make a Dataframe with item(s), weight(s), and value(s) for the knapsack.
val knapsackData = sc.parallelize(knapsackDataList).toDF("item", "weights", "values")

// Display the original data
println("Original Data:")
knapsackData.show()
println("\r\n")

// Create a random maximum weight
val start = N * 0.3
val end = N * 0.6
val W = (math.random * (end - start) + start)

// Show the weight.
println("W: ")
println(W)
println("\r\n")

// Call the knapsack greedy approximation function with the data and the maximum weight W.
val knapResults = knapsack.knapsackApprox(knapsackData, W)

// Show the results Dataframe.
println("Selected Elements:")
knapResults.show()
println("\r\n")

// Find the totals.
val valuesResult = knapResults.agg(sum("values"))
val weightsResult = knapResults.agg(sum("weights"))
val countResult = knapResults.count()

// Show totals for selected elements of knapsack.
println("Value Total:")
valuesResult.show()
println("\r\n")
println("Weights Total:")
weightsResult.show()
println("\r\n")
println("Count:")
println(countResult)
println("\r\n")

And that is it: create some random items, call the knapsackApprox(knapsackData, W) function, and print out the results. Note that I computed the sums outside the main knapsack routine, which only finds the items that fit. Next tasks are: clean up the Scala code, convert the SQL query to a DataFrame window function, and complete the Java version.
