What is the difference between GPU inference and training?

NADDOD Holly Fiber Optic Technician Supervisor Oct 4, 2023

First, training must store not only the model parameters but also the gradients, the optimizer states, and the intermediate states (activations) of each layer during forward propagation. These additional items take more memory than the parameters themselves, so training places a much greater demand on GPU memory.
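
As a rough illustration of that gap, the sketch below counts bytes per parameter. The byte counts are assumptions that this article does not state: 2-byte (fp16) weights and gradients, plus 12 bytes of Adam-style optimizer state per parameter, which is a common mixed-precision accounting. Activations are left out because they depend on batch size and sequence length.

```python
GB = 10**9  # decimal gigabytes


def inference_weight_bytes(num_params, weight_bytes=2):
    """Memory for the weights alone, which is what inference must hold."""
    return num_params * weight_bytes


def training_state_bytes(num_params, weight_bytes=2, grad_bytes=2, optimizer_bytes=12):
    """Weights + gradients + optimizer state (assumed byte counts, no activations)."""
    return num_params * (weight_bytes + grad_bytes + optimizer_bytes)


num_params = 70 * 10**9  # a 70B-parameter model
inference_gb = inference_weight_bytes(num_params) / GB
training_gb = training_state_bytes(num_params) / GB

print(f"inference weights: {inference_gb:.0f} GB")
print(f"training state without activations: {training_gb:.0f} GB")
```

Under these assumptions the training state is eight times the size of the weights before any activations are counted.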


Second, training is a single end-to-end task: the intermediate results of the parallelized forward propagation must be kept for backward propagation. Pipeline parallelism is used to save memory, but the more pipeline stages there are, the more intermediate states must be stored, which worsens the memory shortage. In inference, by contrast, the input requests are independent of one another, and the intermediate states of each layer do not need to be kept for a backward pass. Pipeline parallelism therefore does not require storing many intermediate states.


Let's calculate how much computational power is needed for inference. In the previous estimation of training computational power, two things were ignored for simplicity: KV Cache and memory bandwidth.


KV Cache

What is KV Cache? For each input prompt, computing the first output token requires computing attention for every prompt token from scratch. Generating each subsequent token also requires self-attention over the input prompt and all previously generated tokens, which needs the K and V values of every previous token. Because the parameter matrices of each layer do not change, only the K and V values of the most recently generated token have to be computed; the K and V values of the prompt and of earlier generated tokens are the same as in the previous round.


At this point, we can cache the K and V matrices for each layer so that they do not need to be recomputed when generating the next token. This is called KV Cache. The Q matrix is different every time, so there is no value in caching it. The selective saving of forward activations mentioned earlier for training is a memory-saving trick that trades computation for memory. KV Cache is the reverse: a computation-saving trick that trades memory for computation.
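
The idea can be sketched with a toy single-head attention layer in plain Python. Everything here is illustrative: the names `KVCache`, `decode_without_cache`, and the tiny 4-dimensional random weights are assumptions made for the example, not part of any real model or library. The sketch checks that appending one K and one V per step gives the same output as recomputing K and V for all tokens every step.

```python
import math
import random


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all keys and values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def decode_without_cache(Wq, Wk, Wv, xs):
    """Recompute K and V for every token so far, then attend with the last token's Q."""
    keys = [matvec(Wk, x) for x in xs]
    values = [matvec(Wv, x) for x in xs]
    return attend(matvec(Wq, xs[-1]), keys, values)


class KVCache:
    """Keeps the K and V vectors of earlier tokens so each step computes only one new pair."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, Wq, Wk, Wv, x):
        self.keys.append(matvec(Wk, x))    # K of the newest token only
        self.values.append(matvec(Wv, x))  # V of the newest token only
        return attend(matvec(Wq, x), self.keys, self.values)


random.seed(0)
d = 4  # toy embedding size
Wq, Wk, Wv = ([[random.uniform(-1, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
tokens = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(5)]

cache = KVCache()
max_diff = 0.0
for t, x in enumerate(tokens):
    cached = cache.step(Wq, Wk, Wv, x)
    full = decode_without_cache(Wq, Wk, Wv, tokens[: t + 1])
    max_diff = max(max_diff, max(abs(a - b) for a, b in zip(cached, full)))

print("largest difference between cached and recomputed outputs:", max_diff)
```

The uncached path performs one K and one V projection per previous token at every step, while the cached path performs exactly one of each, at the cost of keeping `cache.keys` and `cache.values` in memory.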


How much storage capacity does KV Cache require? For each layer, every token has one K vector and one V vector, each the size of the embedding. Multiplying by the number of tokens and the batch size gives the KV Cache storage required for that layer. It is important to remember the batch size, as it is a coefficient in the storage and computation of almost every stage of forward and backward propagation.


For example, if the batch size is 8, in LLaMA 2 70B, assuming the number of input and output tokens reaches the model's limit of 4096, the KV Cache for 80 layers would require a total of 2 (K, V) * 80 * 8192 * 4096 * 8 * 2B = 80 GB. With a sufficiently larger batch size, the KV Cache will take more space than the parameters themselves, which occupy 140 GB.
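
The arithmetic can be checked with a few lines of Python. This is a minimal sketch that uses the factors appearing in the formula above (80 layers, embedding size 8192, 4096 tokens, a batch factor of 8, 2 bytes per value) and, like the text, assumes that K and V each have the full embedding width for every token. The function name `kv_cache_bytes` is made up for the example, and the result is reported in binary gigabytes (2^30 bytes), which is the unit in which the product comes to exactly 80.

```python
GIB = 2**30  # binary gigabyte


def kv_cache_bytes(num_layers, hidden_size, num_tokens, batch_size, bytes_per_value=2):
    """Bytes needed to hold K and V for every layer, token, and prompt in the batch."""
    return 2 * num_layers * hidden_size * num_tokens * batch_size * bytes_per_value


total = kv_cache_bytes(num_layers=80, hidden_size=8192, num_tokens=4096, batch_size=8)
per_prompt = kv_cache_bytes(num_layers=80, hidden_size=8192, num_tokens=4096, batch_size=1)

print(total / GIB)       # 80.0
print(per_prompt / GIB)  # 10.0
```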


How much computation can KV Cache save? Computing the K and V matrices for one token in one layer requires 2 (K, V) * 2 (mult, add) * embedding size * embedding size = 4 * 8192 * 8192 operations. Multiplying this by the number of previously processed tokens, the number of layers, and the batch size, we get 4096 * 80 * 8 * 4 * 8192 * 8192 = 640 Tflop (a count of operations, not a rate). Set against the 80 GB of cache, this means that every byte stored saves about 8K operations, or 16K per 2-byte value, which is quite cost-effective.
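
A short sketch of the same calculation, again using the factors from the formulas above. The function name is hypothetical. Dividing the saved operations by the cache size shows that the ratio is simply the embedding size per byte, or twice that per 2-byte value.

```python
def kv_projection_flops(hidden_size, num_tokens, num_layers, batch_size):
    """Operations needed to recompute K and V for all previous tokens."""
    per_token_per_layer = 2 * 2 * hidden_size * hidden_size  # (K, V) * (mult, add)
    return per_token_per_layer * num_tokens * num_layers * batch_size


saved = kv_projection_flops(hidden_size=8192, num_tokens=4096, num_layers=80, batch_size=8)
cache_bytes = 2 * 80 * 8192 * 4096 * 8 * 2  # KV Cache size from the storage formula

print(saved / 2**40)        # 640.0 (tera-operations, binary prefix)
print(saved / cache_bytes)  # 8192.0 operations saved per byte stored
```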


In fact, KV Cache saves much more than that. Computing the K and V matrices is a typical memory-intensive process, because it requires loading the K and V parameter matrices of every layer. Suppose no caching is done, the prompt is very short, and the output length approaches the maximum of 4096 tokens. When we reach the last token, recomputing the K and V matrices for every previous token requires about 40 trillion parameter reads (4096 * 80 * 2 * 8192 * 8192) of 2 bytes each, roughly 80 TB in total. The memory bandwidth of the H100 is only 3.35 TB/s, so the last token alone would take more than 20 seconds of purely repetitive work. Every token would be slower than the one before it, and the total output time would grow with the square of the output length, making this impractical.
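
The memory-traffic estimate can be reproduced as follows. This is a back-of-the-envelope sketch under the paragraph's own assumptions (one full read of the K and V weight matrices per previous token, 2 bytes per parameter, 3.35 TB/s of bandwidth); the function name is invented for the example, and on-chip caching or batching of the reads is ignored.

```python
def kv_recompute_read_seconds(num_tokens, num_layers, hidden_size,
                              bytes_per_param, bandwidth_bytes_per_s):
    """Time spent only on reading K and V weights when nothing is cached."""
    params_read = num_tokens * num_layers * 2 * hidden_size * hidden_size
    return params_read * bytes_per_param / bandwidth_bytes_per_s


params_read = 4096 * 80 * 2 * 8192 * 8192
seconds = kv_recompute_read_seconds(
    num_tokens=4096, num_layers=80, hidden_size=8192,
    bytes_per_param=2, bandwidth_bytes_per_s=3.35e12,
)

print(params_read / 2**40)  # 40.0 (tera-reads, binary prefix)
print(f"{seconds:.1f} seconds for the last token alone")
```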


Is Inference Compute-intensive or Memory-intensive?

The compute required for inference is straightforward to estimate: approximately 2 * the number of output tokens * the number of parameters, in flops. For more details, see the figure from Lequn Chen's article Dissecting Batching Effects in GPT Inference.

Figure: Transformer inference process


However, compute alone doesn't explain everything: the model also has to access GPU memory, and memory bandwidth can become the bottleneck. At the very least, we need to read the parameters from memory, right? Estimating the memory traffic is simple: memory access = number of parameters * 2 bytes. Some intermediate results can stay in the on-chip cache, but the part that doesn't fit there still consumes memory bandwidth, which we won't consider for now.


Without batching, that is, when the model serves a single prompt with a batch size of 1, and with a very short overall context (e.g., only 128 tokens), each parameter loaded (2 bytes) is used for only 128 multiplications and additions over the entire inference process. In this case, the ratio of flops to bytes of memory accessed is only 128. Essentially any GPU is memory-bound in this scenario, spending most of its time loading data from memory.


For the 4090, the ratio of compute to memory bandwidth is 330 Tflops / 1 TB/s = 330 flops per byte, and for the H100 it is 1979 Tflops / 3.35 TB/s = 590 flops per byte. This means that if the number of tokens in the context is less than 330 or 590 respectively, memory access becomes the bottleneck.
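
The comparison can be written as a small check. This sketch follows the article's simplified model, in which each 2-byte parameter loaded performs one multiplication and one addition per token, so the workload's flops per byte equals the number of tokens that share a weight load. The function names are hypothetical, and the hardware figures are the ones quoted above.

```python
def hardware_flops_per_byte(peak_tflops, bandwidth_tb_per_s):
    """Flops the GPU can perform in the time it takes to load one byte."""
    return peak_tflops / bandwidth_tb_per_s


def workload_flops_per_byte(tokens_per_weight_load, bytes_per_param=2):
    """One multiply and one add per token for every parameter loaded."""
    return 2 * tokens_per_weight_load / bytes_per_param


def is_memory_bound(tokens_per_weight_load, peak_tflops, bandwidth_tb_per_s):
    """True when the workload does less work per byte than the hardware can sustain."""
    return (workload_flops_per_byte(tokens_per_weight_load)
            < hardware_flops_per_byte(peak_tflops, bandwidth_tb_per_s))


for name, peak_tflops, bandwidth in (("4090", 330, 1.0), ("H100", 1979, 3.35)):
    ratio = hardware_flops_per_byte(peak_tflops, bandwidth)
    print(name, round(ratio), "flops per byte;",
          "128 tokens memory-bound:", is_memory_bound(128, peak_tflops, bandwidth))
```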


Although the theoretical upper limit of LLaMA 2 is 4096 tokens, many input prompts do not require that many tokens, so memory access can become a bottleneck. In this case, we rely on batch size to compensate. Batch processing in inference means grouping together prompts that arrive at the backend service almost simultaneously. Don't worry, the processing of different prompts within a batch is completely independent, so there won't be any interference. However, the outputs of these prompts will be synchronized, with each prompt in the batch outputting one token per round. Therefore, if some prompts finish outputting first, they will have to wait for the others to finish, resulting in some computational waste.


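Some may wonder why batch processing is needed if the computation required is the same as processing the prompts individually. The answer lies in memory bandwidth: the parameters are read from memory once per round and reused for every prompt in the batch, so a larger batch performs more computation for each byte loaded.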
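
A simple model makes the effect visible. The sketch below assumes that one decoding round takes whichever is longer, the time to read all parameters once or the time to do 2 flops per parameter per prompt. That max-of-two model, the function names, and the use of the H100 figures quoted earlier for a 70B-parameter model on an idealized single device are all simplifying assumptions; KV Cache reads and multi-GPU communication are ignored.

```python
def decode_step_seconds(num_params, batch_size, peak_flops,
                        bandwidth_bytes_per_s, bytes_per_param=2):
    """Time for one round that emits one token per prompt in the batch."""
    compute_seconds = 2 * num_params * batch_size / peak_flops
    memory_seconds = num_params * bytes_per_param / bandwidth_bytes_per_s
    return max(compute_seconds, memory_seconds)


def tokens_per_second(num_params, batch_size, peak_flops, bandwidth_bytes_per_s):
    """Total tokens produced per second across the whole batch."""
    step = decode_step_seconds(num_params, batch_size, peak_flops, bandwidth_bytes_per_s)
    return batch_size / step


NUM_PARAMS = 70e9    # 70B parameters
PEAK_FLOPS = 1979e12  # peak compute figure quoted above
BANDWIDTH = 3.35e12   # bytes per second

for batch_size in (1, 32, 1024):
    step = decode_step_seconds(NUM_PARAMS, batch_size, PEAK_FLOPS, BANDWIDTH)
    rate = tokens_per_second(NUM_PARAMS, batch_size, PEAK_FLOPS, BANDWIDTH)
    print(f"batch {batch_size:5d}: {step * 1000:6.1f} ms per round, {rate:8.0f} tokens/s")
```

In this model a round costs the same at batch size 1 and 32 because loading the parameters dominates, so the larger batch yields 32 times the throughput; only once the batch exceeds the hardware's flops-per-byte ratio does the round time start to grow.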


If a large number of prompts arrive at the server simultaneously, is a larger batch size better? Not necessarily, because the size of the KV Cache is proportional to the batch size, and with a large batch it can occupy a substantial amount of GPU memory. For example, in LLaMA-2 70B, a prompt with a context of about 2048 tokens occupies 5 GB of KV Cache (2 * 80 * 8192 * 2048 * 2B). If the batch size is set to 32, the KV Cache occupies 160 GB of GPU memory, which is larger than the parameters themselves.
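
The trade-off can be turned into a quick capacity check. In this sketch the 5 GB per prompt and the 140 GB of parameters come from the text, while the 640 GB total memory budget is a purely hypothetical figure chosen for illustration, as is the function name `max_batch_size`.

```python
GB = 10**9


def max_batch_size(total_memory_bytes, param_bytes, kv_bytes_per_prompt):
    """Largest batch whose KV Cache fits in the memory left after the parameters."""
    free_bytes = total_memory_bytes - param_bytes
    if free_bytes <= 0:
        return 0
    return free_bytes // kv_bytes_per_prompt


kv_for_batch_32 = 32 * 5 * GB
largest_batch = max_batch_size(
    total_memory_bytes=640 * GB,  # hypothetical memory budget
    param_bytes=140 * GB,
    kv_bytes_per_prompt=5 * GB,
)

print(kv_for_batch_32 // GB)  # 160
print(largest_batch)          # 100
```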


In summary, there are significant differences in storage requirements and computational approaches between GPU inference and training. Training tasks require storing model parameters, gradients, and intermediate states, while in inference tasks, the input data is independent of each other, and the intermediate states do not need to be saved. Additionally, training tasks involve backpropagation and pipelined parallel computing, while inference tasks focus more on computational efficiency and the utilization of memory bandwidth. Therefore, different storage and computation strategies need to be considered for GPU inference and training to optimize performance and resource utilization.


The importance of optical transceivers in GPU inference and training lies in their high-speed data transmission capabilities, providing high bandwidth and low-latency fiber optic communication to accelerate data transfer between GPUs and host systems, thereby improving system performance.


Training Phase

During the training phase of deep learning, a large amount of data is input into the GPU for model training. The GPU accelerates the training process through its parallel computing capabilities. Training data is typically stored in local storage devices such as hard drives or solid-state drives and interacts with the GPU through the host system.


For large-scale deep learning training tasks, there is a significant amount of data and a need for high-bandwidth and low-latency data transfer. In such cases, optical transceivers can play a role. Optical transceivers connect between the host system and the GPU, transmitting data through optical fibers. They provide high-speed data transmission capabilities to meet the throughput and latency requirements of large-scale training tasks.


Inference Phase

During the inference phase of deep learning, a trained model is used to make predictions or classify new data. The GPU performs inference tasks with its high parallel computing capabilities, quickly processing input data and generating results. The inference process typically requires low latency and high throughput, especially for real-time applications and large-scale inference tasks.


In the inference phase, optical transceivers can be used for data input and output. For example, input data can be transmitted from storage devices to the GPU through optical fibers, and the GPU performs inference processing on the data and transmits the results back to the host system or other target devices via optical fibers. The high-speed transmission capabilities of optical transceivers ensure efficient and real-time inference processes.



As a leading provider of comprehensive optical networking solutions, NADDOD is committed to providing users with innovative computing and networking solutions. They offer optimal solutions including switches, AOC/DAC/optical transceivers, intelligent network cards, DPUs, and GPUs, enabling customers to significantly improve their business acceleration capabilities with low cost and outstanding performance.