Minipost: Exploiting Spark as a ThreadPool

Apache Spark is a great tool for massively distributed data analysis and manipulation. It also features a machine learning library, but in the Python variant it is unfortunately nowhere near as good as the awesome scikit-learn library, especially in combination with numpy and scipy, and the linear algebra tools in PySpark make it seem as if the developers didn't even bother. On top of that, my models are currently small enough to train with the aforementioned Python libraries. In my current setting, I repeat the training procedure 100 times for each value of a single parameter. With seven parameter values, this means I have to train my model 700 times. With an average training time of 5 min, I would have to wait 3500 min, which is about 58 hours, or almost two and a half days of training time.
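The back-of-the-envelope numbers are easy to check. The sketch below assumes every run takes exactly 5 minutes and, for the parallel estimate, that our cluster's 70 concurrent slots are always fully busy; real runtimes will vary.

```python
import math

# Workload: 7 parameter values x 100 runs each, ~5 min per run
param_values = 7
runs_per_value = 100
minutes_per_run = 5

total_runs = param_values * runs_per_value      # 700 trainings
sequential_min = total_runs * minutes_per_run   # 3500 minutes
print(f"sequential: {sequential_min} min = {sequential_min / 60:.1f} h "
      f"= {sequential_min / (60 * 24):.1f} days")

# Idealized parallel estimate: 70 slots, runs processed in "waves"
slots = 70
waves = math.ceil(total_runs / slots)           # 10 waves
parallel_min = waves * minutes_per_run          # 50 minutes
print(f"parallel on {slots} slots: {waves} waves = {parallel_min} min")
```

This ignores scheduling overhead and the time needed to ship the data to the workers, so the 50 minutes should be read as a lower bound.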

To speed this up without bothering too much about parallelization code, I exploited Spark as a ThreadPool: I put all parameter combinations into a list, parallelize it as an RDD, and map my training procedure over it. On our small cluster, I can now run 70 jobs in parallel while still using Spark to combine and reduce my results. This sounds easy, but it needs one simple trick: I have to repartition my parameter RDD so that every single item ends up in its own partition and is therefore handled as a separate Spark task. I do this with partitionBy() and then use glom() to turn each partition into a list. Normally, my sequential code looks something like this:

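import numpy as np

def train(data, modelparams):
    model = ...  # doSomething: fit a model on data with modelparams
    return model

def evaluate(model):
    return ...  # compute a scalar quality score for the model

for modelparams in modelparamspace:
    # repeat the training `runs` times for each parameter configuration
    scores = [evaluate(train(data, modelparams)) for _ in range(runs)]
    print(modelparams, np.mean(scores), np.std(scores))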
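To make a loop like this concrete, here is a self-contained toy version. The `train` and `evaluate` bodies are hypothetical stand-ins (closed-form ridge regression on synthetic data, with a fresh random train/test split per run), chosen only so that repeated runs yield slightly different scores; they are not the model used in this post.

```python
import numpy as np

def make_data(n=200, d=5, seed=0):
    # Hypothetical synthetic regression data: y = X w + noise
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    w = rng.normal(size=d)
    y = X @ w + 0.5 * rng.normal(size=n)
    return X, y

def train(data, alpha, rng):
    # Hypothetical stand-in: closed-form ridge regression on a random 80% split
    X, y = data
    idx = rng.permutation(len(y))
    cut = int(0.8 * len(y))
    tr, te = idx[:cut], idx[cut:]
    A = X[tr].T @ X[tr] + alpha * np.eye(X.shape[1])
    w = np.linalg.solve(A, X[tr].T @ y[tr])
    return w, (X[te], y[te])

def evaluate(result):
    # Mean squared error on the held-out 20%
    w, (X_te, y_te) = result
    return float(np.mean((X_te @ w - y_te) ** 2))

data = make_data()
modelparamspace = [0.01, 0.1, 1.0, 10.0]  # hypothetical regularization strengths
runs = 5
rng = np.random.default_rng(42)

summary = {}
for alpha in modelparamspace:
    scores = [evaluate(train(data, alpha, rng)) for _ in range(runs)]
    summary[alpha] = (np.mean(scores), np.std(scores))
    print(alpha, summary[alpha])
```

Every run here is independent of the others, which is exactly what makes it safe to farm the runs out to separate workers.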

Now, instead, I do something like the following:

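partitions_count = len(modelparamspace) * runs
paramRDD = sc.parallelize(modelparamspace * runs)  # modelparamspace must be a list

# ship the training data to every worker once
data_bc = sc.broadcast(data)

results = (paramRDD
    .zipWithIndex()                              # -> (params, index)
    .map(lambda x: (x[1], x[0]))                 # -> (index, params)
    .partitionBy(partitions_count, lambda k: k)  # index k -> partition k
    .glom()                                      # each partition -> [(index, params)]
    .map(lambda part: (part[0][1], train(data_bc.value, part[0][1])))
    .collect())                                  # list of (params, result) on the driver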
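Because the trick hinges on each index landing in its own partition, it can help to replay the key steps in plain Python without a cluster. The sketch below mimics `zipWithIndex()`, the key/value swap, and `partitionBy()` with an identity partition function; it assumes, as PySpark does, that a key is placed into partition `partitionFunc(key) % numPartitions`. The parameter values are made up.

```python
# Hypothetical inputs: 3 parameter values, 2 runs each
modelparamspace = [0.1, 1.0, 10.0]
runs = 2

items = modelparamspace * runs  # list repetition: [0.1, 1.0, 10.0, 0.1, 1.0, 10.0]
partitions_count = len(items)

# zipWithIndex() yields (element, index); the map() then swaps to (index, element)
pairs = list(zip(items, range(len(items))))
keyed = [(i, p) for p, i in pairs]

def identity(k):
    return k

# partitionBy(n, identity): key k goes into bucket identity(k) % n
partitions = [[] for _ in range(partitions_count)]
for k, v in keyed:
    partitions[identity(k) % partitions_count].append((k, v))

# glom(): each partition becomes one list; here each list holds exactly one pair
for part in partitions:
    print(part)
```

The duplicate parameter values at indices 0 and 3 still end up in different partitions, because the index, not the parameter, is the key.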

First, I index my parameters with zipWithIndex(), so that two runs with the same parameter configuration still get distinct keys and are executed separately. I then swap index and parameters, because zipWithIndex() appends the index as the second element, while partitionBy() uses the first element as the key. The partitionBy() call then creates the desired number of partitions, partitions_count, with the identity function as the partitioner; since PySpark places key k into partition k % partitions_count and the indices run from 0 to partitions_count - 1, every run gets its own partition. The glom() call turns the elements of each partition into a list (see the PySpark documentation of RDD.glom()), so each list holds exactly one (index, parameters) pair, and the training function is applied to that single element. Beforehand, I broadcast the data across the cluster with sc.broadcast(), so each worker receives it only once and accesses it through the broadcast variable's .value attribute. To identify the execution parameters that belong to each result, I attach them to the result in the final map step. By calling collect(), I then gather all results on the driver again, where I can process or plot them.
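Once `collect()` returns, the results are ordinary Python objects on the driver. Assuming each collected element is a `(params, score)` pair with a scalar score and hashable parameters (if `train` returns a model instead, an evaluation step would be needed first), the per-parameter mean and standard deviation of the sequential version can be recovered as sketched below. The numbers are invented for illustration.

```python
from collections import defaultdict

import numpy as np

# Hypothetical output of .collect(): (params, score) pairs in arbitrary order
results = [(0.1, 0.52), (1.0, 0.40), (0.1, 0.48), (1.0, 0.44)]

# Group scores by their parameter value
by_param = defaultdict(list)
for params, score in results:
    by_param[params].append(score)

# Mean and (population) standard deviation per parameter, as in the sequential loop
summary = {p: (float(np.mean(s)), float(np.std(s))) for p, s in by_param.items()}
for p in sorted(summary):
    print(p, summary[p])
```

With only a few hundred results this is cheap on the driver; for much larger outputs, the grouping could instead be done on the cluster before collecting.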

The same trick could also be applied to, e.g., training a Word2Vec model with gensim on each worker instead of using Spark's own implementation. Alternatively, one could train a TensorFlow model on each worker, which works like a charm when every worker also has a powerful GPU.