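Decision tree learners such as C4.5 are easy to parallelize, because once a node has been split its children can be grown independently of each other. The following is an example.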
This is a non-parallel version:
public void learnFromDataSet(Iterable> dataset){ for(Sample sample : dataset){ model.addSample((MapBasedBinarySample )sample); } Queue > Q = new LinkedList >(); TreeNode root = model.selectRootTreeNode(); model.addTreeNode(root); Q.add(root); while (!Q.isEmpty()){ TreeNode v = Q.poll(); if(v.getDepth() >= model.getMaxDepth()){ continue; } FeatureSplit featureSplit = model.selectFeature(v); if(featureSplit.getFeatureId() == null){ continue; } v.setFeatureSplit(featureSplit); Pair , TreeNode > children = model.newTreeNode(v, featureSplit); TreeNode leftNode = children.getKey(); TreeNode rightNode = children.getValue(); if(leftNode != null && leftNode.getSampleSize() > model.getMinSampleSizeInNode()){ v.setLeft(leftNode); model.addTreeNode(leftNode); Q.add(leftNode); } if(rightNode != null && rightNode.getSampleSize() > model.getMinSampleSizeInNode()){ v.setRight(rightNode); model.addTreeNode(rightNode); Q.add(rightNode); } } }
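And this is the parallel version. It takes up to 10 nodes from the queue at a time, splits them concurrently on a thread pool, and waits for the whole batch to finish before taking the next one: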
public class NodeSplitThread implements Runnable{ private TreeNodenode = null; private Queue > Q = null; public NodeSplitThread(TreeNode node, Queue > Q){ this.node = node; this.Q = Q; } @Override public void run() { if(node.getDepth() >= model.getMaxDepth()){ return; } FeatureSplit featureSplit = model.selectFeature(node); if(featureSplit.getFeatureId() == null){ return; } node.setFeatureSplit(featureSplit); Pair , TreeNode > children = model.newTreeNode(node, featureSplit); TreeNode leftNode = children.getKey(); TreeNode rightNode = children.getValue(); if(leftNode != null && leftNode.getSampleSize() > model.getMinSampleSizeInNode()){ node.setLeft(leftNode); model.addTreeNode(leftNode); Q.add(leftNode); } if(rightNode != null && rightNode.getSampleSize() > model.getMinSampleSizeInNode()){ node.setRight(rightNode); model.addTreeNode(rightNode); Q.add(rightNode); } } } public List > pollTopN(Queue > Q, int n){ List > ret = new ArrayList >(); for(int i = 0; i < n; ++i){ if(Q.isEmpty()) break; TreeNode node = Q.poll(); ret.add(node); } return ret; } @Override public void learnFromDataSet(Iterable > dataset){ for(Sample sample : dataset){ model.addSample((MapBasedBinarySample )sample); } Queue > Q = new ConcurrentLinkedQueue >(); TreeNode root = model.selectRootTreeNode(); model.addTreeNode(root); Q.add(root); ExecutorService threadPool = Executors.newFixedThreadPool(10); while (!Q.isEmpty()){ List > nodes = pollTopN(Q, 10); List tasks = new ArrayList (nodes.size()); for(TreeNode node : nodes){ Future task = threadPool.submit(new NodeSplitThread(node, Q)); tasks.add(task); } for(Future task : tasks){ try { task.get(); } catch (InterruptedException e) { continue; } catch (ExecutionException e) { continue; } } } threadPool.shutdown(); try { threadPool.awaitTermination(60, TimeUnit.SECONDS); } catch (InterruptedException e) { threadPool.shutdownNow(); Thread.interrupted(); } threadPool.shutdownNow(); }