Paper summary: Federated Learning

On Thursday, April 6, Google announced Federated Learning. The announcement didn't make much of a splash, but I think it is potentially transformative. Instead of uploading all the data from the smartphones to the cloud to train the model in the datacenter, federated learning enables in-situ training on the smartphones themselves. The datacenter is still involved, but only to aggregate the smartphone-updated local models into the new, improved global model.

This is a win-win-win situation. As a smartphone user, your privacy is preserved since your data remains on your device, but you still get the benefits of machine learning on your smartphone. Google gets what it needs. First, it perpetually learns from cumulative user experience and improves its software/applications: it collects insights without collecting data (and some of these insights may still be transferable to advertising income). Second, Google outsources the training to the users' smartphones: it makes the smartphones do the work of updating the model, rather than using servers in its datacenters. This may seem like penny-pinching, but if you consider the sheer number of Android smartphones out there (more than 1.5 billion according to numbers from the end of 2015), the computation savings are huge. Notice that Google is winning twice, while you win only once. :-)

After skimming the Google announcement, I was excited, because I had predicted this in January 2016. When I read the TensorFlow whitepaper, I was struck by its peculiar emphasis on heterogeneous device support, especially on the smartphone. I predicted that Google was up to something more than just inference/serving-layer support on smartphones. I got the mechanism wrong, though, since I didn't have machine learning experience then.

The Google announcement links to some papers, and I read the most relevant paper carefully. Below is my summary of the paper: Communication-Efficient Learning of Deep Networks from Decentralized Data.

Applications that are an ideal fit for federated learning

The paper pitches smartphone applications as the ideal domain for federated learning.
For smartphones, upload bandwidth is a bottleneck. Instead of uploading the data to the cloud for training, let the smartphone train and update the model and transmit it back. This makes sense when the model update is smaller than the data. There is another Google paper that provides tricks for further compressing the learned model at the smartphone before transmitting it back to the cloud. Since the upload rate is about one third of the download rate, such techniques are beneficial.

A second big benefit for the smartphone applications domain is preserving the private data of the smartphone user. Finally, in the smartphone applications domain the labels on the data are available or inferable from user interaction (albeit in an application-dependent way).

Concrete examples of these applications are "image classification: predicting which images are most likely to be viewed and shared", and "language modeling: improving voice recognition and text entry on smartphone keyboard." Google is already using federated learning in the Gboard keyboard on Android for improving text entry. On-device training uses a miniature version of TensorFlow.

Related work

There has been work on iteratively averaging locally trained models, but inside the datacenter, not at the edge. See: Sixin Zhang, Anna E Choromanska, and Yann LeCun. Deep learning with elastic averaging SGD. In NIPS. 2015. and also this one: Ryan McDonald, Keith Hall, and Gideon Mann. Distributed training strategies for the structured perceptron. In NAACL HLT, 2010.

There has been work motivated by privacy, but with limited empirical results. And there has been work in the convex setting on distributed consensus algorithms (the machine learning version of the problem, not the distributed systems version).

The contributions of the paper

The paper introduces the FedAvg algorithm for federated learning. The algorithm is not an intellectually challenging innovation, as it prescribes a small variation on traditional SGD training. So the paper is careful not to claim too much credit for the algorithm. Instead the paper puts the weight of its contributions on items 1 and 3 below, and downplays 2 in comparison: "1) the identification of the problem of training on decentralized data from mobile devices as an important research direction; 2) the selection of a straightforward and practical algorithm that can be applied to this setting; and 3) an extensive empirical evaluation of the proposed approach."

The algorithm

The algorithm uses a synchronous update scheme that proceeds in rounds of communication. There is a fixed set of K clients, each with a fixed local dataset. At the beginning of each round, a random fraction C of workers is selected, and the server sends the current global algorithm state to each of these workers (e.g., the current model parameters). Only a fraction of workers are selected for efficiency, as the experiments show diminishing returns for adding more workers beyond a certain point. Each selected worker then performs local computation based on the global state and its local dataset, and sends an update to the server. The server then applies these updates to its global state, and the process repeats.

So here are the high-level steps in the algorithm:

  1. Workers are sent the model by the server
  2. Workers compute an updated model based on their local data
  3. The updated models are sent from the workers to the server
  4. The server aggregates these models (by averaging) to construct the new global model

These steps are almost the same as in traditional ML/DL training with a parameter server and workers. But there are some minor differences. Difference in step 1: Not all workers, but a subset of workers are chosen. Differences in step 2: Workers are also producers of data; they work on the data they produce. Workers may do multiple iterations on the local model update. Difference in step 3: The model, not the gradients, is transmitted back.

For step 3, I don't know why the model is sent rather than the gradients. The paper presents a derivation to argue that both are equivalent, but then does not provide any justification for transmitting the model instead of the gradients. There is neither an explanation nor an experiment comparing the two approaches.
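
For reference, the equivalence can be written out for the case where each worker takes a single gradient step per round. The notation here is an assumption chosen for illustration: g_k is the gradient computed by worker k at the current model w_t, n_k is the number of local examples at worker k, n is the sum of the n_k, and η is the learning rate.

```latex
% One gradient step per round: averaging gradients equals averaging models.
\begin{align*}
w_{t+1} &= w_t - \eta \sum_{k} \frac{n_k}{n}\, g_k \\
        &= \sum_{k} \frac{n_k}{n} \left( w_t - \eta\, g_k \right)
           && \text{since } \sum_{k} \frac{n_k}{n} = 1 \\
        &= \sum_{k} \frac{n_k}{n}\, w^{k}_{t+1},
           && \text{where } w^{k}_{t+1} = w_t - \eta\, g_k .
\end{align*}
```

Note that this identity only covers one local step. Once a worker takes several local steps, there is no single gradient to send back; the closest analogue would be the difference between the local model and w_t, which carries the same information as the model itself. This is one possible reading, not a justification taken from the paper.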

Here is the algorithm (Algorithm 1, FederatedAveraging, in the paper; the paper calls workers clients).

The amount of computation is controlled by three key parameters: C, the fraction of workers that perform computation on each round; E, the number of training passes each worker makes over its local dataset on each round; and B, the local minibatch size used for the worker updates. B = ∞ indicates that the full local dataset is treated as a single minibatch.
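
To make the roles of C, E, and B concrete, here is a minimal sketch of the round structure in plain Python. It is not the paper's code: the toy linear model y = w[0]*x + w[1], the squared-error loss, the learning rate, the synthetic data, and the function names (client_update, fed_avg, make_toy_clients) are all assumptions made for illustration. Averaging over the selected workers weighted by local dataset size is my reading of the aggregation step.

```python
import random

def client_update(w, data, E, B, lr, rng):
    """Local training on one client: E passes of minibatch SGD starting from w."""
    w = list(w)  # copy so the server's model is not mutated
    batch_size = len(data) if B is None else B  # B=None stands in for B = infinity
    for _ in range(E):
        order = list(range(len(data)))
        rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = [data[i] for i in order[start:start + batch_size]]
            # Gradient of mean squared error for the toy model y = w[0]*x + w[1].
            g0 = sum(2 * (w[0] * x + w[1] - y) * x for x, y in batch) / len(batch)
            g1 = sum(2 * (w[0] * x + w[1] - y) for x, y in batch) / len(batch)
            w[0] -= lr * g0
            w[1] -= lr * g1
    return w

def fed_avg(clients, rounds, C, E, B, lr=0.1, seed=0):
    """Synchronous rounds: sample clients, train locally, average the returned models."""
    rng = random.Random(seed)
    K = len(clients)
    w = [0.0, 0.0]  # one shared initialization for every client
    for _ in range(rounds):
        m = max(int(C * K), 1)  # at least one client per round, so C=0 means one client
        selected = rng.sample(range(K), m)
        local = [(len(clients[k]), client_update(w, clients[k], E, B, lr, rng))
                 for k in selected]
        n = sum(n_k for n_k, _ in local)
        # Weight each returned model by its share of the examples seen this round.
        w = [sum(n_k * w_k[j] for n_k, w_k in local) / n for j in range(len(w))]
    return w

def make_toy_clients(K=10, seed=1):
    """Hypothetical unbalanced data: client k holds 8 + 2k noiseless points of y = 2x + 1."""
    rng = random.Random(seed)
    clients = []
    for k in range(K):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(8 + 2 * k)]
        clients.append([(x, 2.0 * x + 1.0) for x in xs])
    return clients

w = fed_avg(make_toy_clients(), rounds=50, C=0.5, E=5, B=4)
print(w)  # learned [slope, intercept]; the data were generated with slope 2 and intercept 1
```

Raising E or lowering B increases the local work per round without adding any communication, which is the trade-off the paper explores.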

And this is the biggest difference in the algorithm from traditional synchronous SGD:
In each synchronous round, the workers can perform multiple local model-update steps before uploading the model back to the cloud. Since the local dataset is small, iterating over it is fast. And since communication is already high-latency, the workers may as well do many local computation steps to make each round worthwhile.
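
As a rough guide to how much local work E and B imply, the number of local SGD steps per round is about E times the number of minibatches in the local dataset. The helper below is a small sketch; the function name is hypothetical, counting a partial last minibatch as a full step is my choice, and the 600-example figure assumes MNIST's 60,000 training examples split evenly over 100 workers.

```python
import math

def local_updates_per_round(n_k, E, B=None):
    """Number of local SGD steps a worker with n_k examples takes in one round.

    B=None stands in for B = infinity (one full-batch step per pass).
    """
    if B is None:
        return E
    return E * math.ceil(n_k / B)

# Illustrative settings for a worker holding 600 examples.
for E, B in [(1, None), (1, 10), (5, 10), (20, 10)]:
    print(E, B, local_updates_per_round(600, E, B))
```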

The paper makes a big deal of the unbalanced data distribution and especially the non-IID (not independent and identically distributed) data at the workers. However, there is nothing special in the FedAvg algorithm to address the non-IID data challenge. It works because SGD tolerates noise. And most likely having many workers participate in each round helps a lot. Even with C=0.1, the number of smartphones used for training can be amazingly high. In the experimental results, 100 workers are used for the MNIST image classification application, and 1146 workers are used for the Shakespeare language modeling application. This is way more workers than are used in traditional ML, certainly for the MNIST and Shakespeare datasets. So the sheer number of workers helps compensate for the non-IID challenge.
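
To give a feel for how skewed the non-IID setting is: as far as I recall, the paper's pathological MNIST split sorts the examples by digit label, cuts them into 200 shards of 300, and gives each of the 100 workers 2 shards, so most workers see only two digits. The sketch below reproduces that style of split on a list of labels; the function name and the scaled-down sizes in the usage lines are hypothetical.

```python
import random

def shard_partition(labels, num_clients, shards_per_client, seed=0):
    """Label-sorted shard split: returns one list of example indices per client."""
    # Sort example indices by label so that each shard is dominated by one label.
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_clients * shards_per_client
    shard_size = len(order) // num_shards  # leftover examples, if any, are dropped
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    random.Random(seed).shuffle(shards)  # deal shards to clients at random
    return [
        [i for shard in shards[c * shards_per_client:(c + 1) * shards_per_client]
         for i in shard]
        for c in range(num_clients)
    ]

# Scaled-down stand-in for MNIST labels: 6000 examples, 600 per digit.
labels = [i % 10 for i in range(6000)]
parts = shard_partition(labels, num_clients=100, shards_per_client=2)
print(max(len({labels[i] for i in part}) for part in parts))  # distinct digits per client, at most
```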


Can Federated Learning bootstrap from a zero-trained model?

Before reading the paper, I thought "maybe the smartphones are not initialized from random parameters, but with partially or mostly trained parameters". The experiments suggest otherwise: the smartphones start from an untrained model. However, Figure 1 in the paper shows that the smartphones (i.e., workers) should all start from the same common initialization, because with independent initialization the averaged model does not converge.

How long is a round?

The experiments investigate how many rounds are needed to converge, but the paper does not say how long a round takes. My estimate is that a round should take around 5 minutes or so. In practical use, Google employs as workers the smartphones that are plugged into a wall outlet and connected to Wi-Fi. (Android phones provide this information to Google, and a subset among them is randomly selected as workers. In our PhoneLab project, we also uploaded data from phones only when they were plugged in.) Since this is synchronous SGD, a round progresses at the rate of the slowest worker. So they must be using backup workers (maybe 5% or 10%) to combat the straggler problem.
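
To make the straggler argument concrete, here is a tiny sketch of how backup workers would shorten a synchronous round. This is speculation about the deployment, not something described in the paper: the function name and the finish times are hypothetical.

```python
def round_duration(finish_times, needed):
    """Time at which the server holds `needed` responses; slower workers are ignored."""
    if not 1 <= needed <= len(finish_times):
        raise ValueError("needed must be between 1 and the number of workers")
    return sorted(finish_times)[needed - 1]

# Hypothetical finish times in seconds for five selected workers; one is a straggler.
times = [30, 42, 35, 300, 38]
print(round_duration(times, needed=5))  # wait for everyone: the straggler sets the pace
print(round_duration(times, needed=4))  # one backup worker: the straggler is dropped
```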

FedAvg uses synchronous SGD, but would it also work with asynchronous rounds?

Unfortunately, there are no experiments that investigate this question.

Experimental results

The experiments use MNIST image classification, CIFAR-10 image classification, and language modeling on the Shakespeare dataset, predicting the next character (alphabet character, that is).

The paper says this: "Our goal is to use additional computation in order to decrease the number of rounds of communication needed to train a model. The speedups we achieve are due primarily to adding more computation on each client, once a minimum level of parallelism over clients is used." They achieve this by increasing E (or decreasing B), once C is saturated (which tends to happen at C=0.1 in their setup).

Even the pathologically biased distribution of local datasets works great because the number of workers per round is high, which provides smoothing. "Both standard SGD and FedAvg with only one client per round (C = 0), demonstrate significant oscillations in accuracy, whereas averaging over more clients smooths this out."

