PyTorch's compilation and matrix multiplication.
PyTorch is a popular open-source machine learning library known for its flexibility and dynamic computation graph. However, when it comes to optimizing certain mathematical operations, it falls short. One example of this is its approach to matrix multiplication, which, in some cases, lacks basic optimizations. In this article, we'll explore PyTorch's shortcomings when it comes to optimizing matrix multiplication and discuss a specific scenario where it doesn't use the most efficient approach.
Matrix multiplication is a fundamental operation in various machine learning and scientific computing tasks. It forms the backbone of many deep learning models, including neural networks. To perform matrix multiplication, PyTorch offers several methods, including `torch.mm()`, `torch.matmul()`, and the `@` operator. While these methods are versatile and easy to use, they may not always perform optimally.
Matrix multiplication can be a computationally intensive task, especially when dealing with large matrices. In the naive approach, to compute the product of two matrices A and B, you iterate over the rows of A and the columns of B and, for each pair, compute a dot product: a sequence of scalar multiplications and additions. This approach, also known as the standard matrix multiplication algorithm, has a time complexity of O(n^3) for n x n matrices.
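As a rough illustration of the standard algorithm, here is a minimal sketch in plain Python. The function name `naive_matmul` is made up for this example, and matrices are represented as lists of rows; this is not how PyTorch implements multiplication internally.

```python
def naive_matmul(a, b):
    """Multiply two matrices (lists of rows) with the textbook triple loop."""
    n, k, m = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    result = [[0.0] * m for _ in range(n)]
    for i in range(n):          # each row of a
        for j in range(m):      # each column of b
            for p in range(k):  # dot product of row i of a and column j of b
                result[i][j] += a[i][p] * b[p][j]
    return result


print(naive_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))
```

The three nested loops are where the O(n^3) cost comes from: for square matrices each loop runs n times.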
To improve efficiency, more advanced algorithms such as Strassen's algorithm and the Coppersmith–Winograd algorithm have been developed, which reduce the number of elementary multiplications needed to compute the result. In the context of deep learning, PyTorch utilizes highly optimized and efficient matrix multiplication libraries like Intel MKL or NVIDIA cuBLAS when available, which can significantly speed up these operations.
However, these libraries only speed up each individual multiplication. When it comes to deciding which multiplications to perform in the first place, PyTorch can easily miss basic optimization opportunities.
Let's look at a specific example where PyTorch fails to apply a basic optimization. Suppose we want to compute the fourth power of a matrix A. One straightforward approach is to write the product out in full:
A^4 = A * A * A * A
In this case, PyTorch evaluates the expression from left to right and performs three matrix multiplications, even though the same result can be obtained with just two. This is because PyTorch does not exploit the associativity of matrix multiplication to reuse the intermediate product A^2; it simply repeats the multiplication.
For a random 1000x1000 matrix, `matmul4` (the chained product above) takes ~92 ms.
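The original snippet is not reproduced here, so the following is a sketch of what such a measurement might look like. `matmul4` is the name used in the text; the `time_ms` helper is hypothetical, and the numbers it prints depend on hardware and the BLAS backend, so they need not match the figure quoted above. The sketch times CPU tensors only; on a GPU, `torch.cuda.synchronize()` would be needed around the timed region.

```python
import time

import torch


def matmul4(mtx):
    # Evaluated left to right as ((mtx @ mtx) @ mtx) @ mtx: three multiplications.
    return mtx @ mtx @ mtx @ mtx


def time_ms(fn, *args, repeats=5):
    """Hypothetical helper: average wall-clock time of fn(*args) in milliseconds."""
    fn(*args)  # warm-up call, excluded from the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) * 1000 / repeats


mtx = torch.randn(1000, 1000)  # random 1000x1000 matrix on the CPU
print(f"matmul4: {time_ms(matmul4, mtx):.1f} ms")
```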
An attentive reader may suggest using `torch.compile`, but, unfortunately, this doesn't make much difference.
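For reference, wrapping the function with `torch.compile` might look like the sketch below. It assumes PyTorch 2.0 or later and a working compiler backend; the names `matmul4_compiled` and `run_compiled` are illustrative.

```python
import torch


def matmul4(mtx):
    # Same chained product as before: three multiplications.
    return mtx @ mtx @ mtx @ mtx


# torch.compile only wraps the function here; the actual compilation
# happens lazily, on the first call with real inputs.
matmul4_compiled = torch.compile(matmul4)


def run_compiled(n=1000):
    """Call the compiled function on a random n x n matrix (triggers compilation)."""
    mtx = torch.randn(n, n)
    return matmul4_compiled(mtx)
```

When timing the compiled version, make at least one warm-up call first so that the one-off compilation cost is not counted.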
To compute the fourth power of a matrix more efficiently, one can exploit the properties of matrix exponentiation. Using exponentiation by squaring, you can compute A^4 as follows:
A^4 = (A^2)^2
This method reduces the number of matrix multiplications, and in this case, only two multiplications are required: A^2 and (A^2)^2. Using this optimized approach can significantly reduce the computational burden when dealing with large matrices.
The code to accomplish this is trivial and results in the desired speedup.
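A minimal sketch of the two-multiplication version follows. The function name `matmul4_opt` is an assumption, since the original snippet is not shown.

```python
import torch


def matmul4_opt(mtx):
    # First multiplication: A^2.
    squared = mtx @ mtx
    # Second multiplication: (A^2)^2 = A^4.
    return squared @ squared


mtx = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(matmul4_opt(mtx))
```

Storing the intermediate `squared` is the whole trick: the product A^2 is computed once and reused, instead of being rebuilt implicitly by the chained expression.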
The same can be achieved by using `torch.matrix_power(mtx, 4)`, which delivers similar performance.
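A short sketch comparing the built-in with the hand-written squaring is given below. `torch.matrix_power` is an alias for `torch.linalg.matrix_power`; the variable names and the choice of `float64` (used only to keep rounding differences small in the comparison) are assumptions of this example.

```python
import torch

mtx = torch.randn(1000, 1000, dtype=torch.float64)

# Hand-written squaring: two multiplications.
squared = mtx @ mtx
by_hand = squared @ squared

# Built-in integer matrix power.
built_in = torch.matrix_power(mtx, 4)

print("results agree within tolerance:", torch.allclose(by_hand, built_in))
```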
Note that this optimization relies on associativity, and since floating-point arithmetic is not associative, the result of the optimized version is generally not bit-for-bit equal to that of the non-optimized one. This doesn't mean it's incorrect, and the difference is certainly much smaller than the error introduced by quantization and other hacks that are often employed to speed up training or inference. On the bright side, the performance gap between the optimized and non-optimized versions grows with both the matrix dimensions and the power.
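To make both points concrete, here is a sketch that generalizes the trick to an arbitrary positive integer power and measures how far the two results drift apart. The function names `power_naive` and `power_by_squaring`, the matrix size, and the scaling factor are all assumptions of this example, and the printed difference will vary by machine.

```python
import torch


def power_naive(mtx, n):
    """Compute mtx^n with n - 1 multiplications, for integer n >= 1."""
    result = mtx
    for _ in range(n - 1):
        result = result @ mtx
    return result


def power_by_squaring(mtx, n):
    """Compute mtx^n with at most 2 * floor(log2(n)) multiplications, n >= 1."""
    result = None
    base = mtx
    while n > 0:
        if n & 1:  # this bit of the exponent is set: fold base into the result
            result = base if result is None else result @ base
        n >>= 1
        if n:      # square only if more bits remain
            base = base @ base
    return result


torch.manual_seed(0)
mtx = torch.randn(200, 200) / 200 ** 0.5  # scaled so the powers stay moderate in size
naive = power_naive(mtx, 16)              # 15 multiplications
fast = power_by_squaring(mtx, 16)         # 4 multiplications
print("bitwise equal:", torch.equal(naive, fast))
print("max abs difference:", (naive - fast).abs().max().item())
```

For the 16th power the naive loop needs 15 multiplications while squaring needs 4, which is why the gap widens as the exponent grows.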
While PyTorch is a powerful and widely used library for deep learning and scientific computing, it's not immune to optimization issues. In the case of matrix multiplication, PyTorch may not always apply the most efficient high-level algorithms. To achieve the best performance in such cases, users must be aware of these limitations and apply their own optimizations when necessary. Understanding the underlying algorithms and mathematical properties is crucial to harnessing PyTorch's full potential and ensuring efficient computation in complex machine learning tasks.
You can play with the above examples in the Colab playground.