Coding for Neon
102159
Issue 04
Copyright © 2020 Arm Limited (or its affiliates). All rights reserved.
Non-Confidential
Page 35 of 46
the 16 elements from the second matrix. Each 128-bit register holds four 32-bit values, representing
an entire matrix column.
Similarly, the following code shows how we use the ST1 instruction to store the result back to
memory:
ST1 {V8.4S, V9.4S, V10.4S, V11.4S}, [X0]
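The following is a plain-C model of what the four-register LD1 and ST1 forms do to a 4x4 matrix. It is a sketch under our own assumptions. The `vec4` type and the `ld1_4reg`, `st1_4reg`, and `elem` helpers are hypothetical stand-ins for V registers and instructions, not Arm intrinsics. The matrix is assumed to be stored column-major in 16 consecutive floats, so each register receives one column.

```c
#include <assert.h>
#include <string.h>

/* Hypothetical stand-in for one 128-bit V register viewed as four 32-bit
 * float lanes (the .4S arrangement). Not an Arm type. */
typedef struct { float lane[4]; } vec4;

/* Models LD1 {Vt.4S, Vt2.4S, Vt3.4S, Vt4.4S}, [Xn]: 16 consecutive floats
 * are split into four groups of four, one group per register, with no
 * de-interleaving. For a column-major 4x4 matrix, reg[c] receives column c. */
static void ld1_4reg(vec4 reg[4], const float *mem)
{
    for (int c = 0; c < 4; c++)
        memcpy(reg[c].lane, mem + 4 * c, sizeof reg[c].lane);
}

/* Models ST1 {Vt.4S, Vt2.4S, Vt3.4S, Vt4.4S}, [Xn]: the inverse of ld1_4reg. */
static void st1_4reg(float *mem, const vec4 reg[4])
{
    for (int c = 0; c < 4; c++)
        memcpy(mem + 4 * c, reg[c].lane, sizeof reg[c].lane);
}

/* Element (row r, column c) of a column-major matrix lives at index 4*c + r,
 * so after ld1_4reg it sits in lane r of register c. */
static float elem(const float *m, int r, int c)
{
    assert(r >= 0 && r < 4 && c >= 0 && c < 4);
    return m[4 * c + r];
}
```

Loading a matrix and storing it back with the same register list reproduces the original 16 floats exactly, which is why the result can be written with a single ST1.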
The following code shows how we calculate one column of results using just four Neon instructions, one multiply (FMUL) followed by three multiply-accumulates (FMLA):
FMUL V8.4S, V0.4S, V4.S[0] // rslt col0 = (mat0 col0) * (mat1 col0 elt0)
FMLA V8.4S, V1.4S, V4.S[1] // rslt col0 += (mat0 col1) * (mat1 col0 elt1)
FMLA V8.4S, V2.4S, V4.S[2] // rslt col0 += (mat0 col2) * (mat1 col0 elt2)
FMLA V8.4S, V3.4S, V4.S[3] // rslt col0 += (mat0 col3) * (mat1 col0 elt3)
The first FMUL instruction implements the operation that is highlighted in the previous diagram.
Matrix elements x0, x1, x2, and x3 (in the four lanes of register V0) are each multiplied by y0 (element 0 in register V4), and the results are stored in V8.
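The lane-by-lane effect of FMUL by element can be modelled in plain C as follows. The `vec4` type and the `fmul_elem` name are hypothetical. On AArch64 the closest ACLE intrinsic is `vmulq_laneq_f32`, but this sketch deliberately avoids intrinsics so that it runs on any host.

```c
#include <assert.h>

/* Hypothetical stand-in for a V register holding four float lanes. */
typedef struct { float lane[4]; } vec4;

/* Models FMUL Vd.4S, Vn.4S, Vm.S[idx] (multiply by element): every lane
 * of Vn is multiplied by the one scalar found in lane idx of Vm. */
static vec4 fmul_elem(vec4 n, vec4 m, int idx)
{
    assert(idx >= 0 && idx < 4);
    vec4 d;
    for (int i = 0; i < 4; i++)
        d.lane[i] = n.lane[i] * m.lane[idx];
    return d;
}
```

With V0 modelled as {x0, x1, x2, x3} and V4 as {y0, y1, y2, y3}, `fmul_elem(v0, v4, 0)` yields {x0*y0, x1*y0, x2*y0, x3*y0}, which is the value the first FMUL leaves in V8.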
Subsequent FMLA instructions operate on the other columns of the first matrix, multiplying by
corresponding elements of the first column of the second matrix. Results are accumulated into V8 to
give the first column of values for the result matrix.
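Together, the FMUL and the three FMLAs form a matrix-by-vector product. Each column of the first matrix is scaled by one element of the vector, and the scaled columns are summed. The sketch below models that sequence in plain C. `vec4`, `fmla_elem`, and `mat_vec_mul` are hypothetical names. The real FMLA is fused and rounds once, while this model may round the product and the sum separately, so results can differ in the last bit for general inputs.

```c
#include <assert.h>

/* Hypothetical stand-in for a V register holding four float lanes. */
typedef struct { float lane[4]; } vec4;

/* Models FMLA Vd.4S, Vn.4S, Vm.S[idx]: d[i] += n[i] * m[idx] in every lane.
 * Real FMLA is fused (one rounding); this model may round twice. */
static void fmla_elem(vec4 *d, vec4 n, vec4 m, int idx)
{
    assert(idx >= 0 && idx < 4);
    for (int i = 0; i < 4; i++)
        d->lane[i] += n.lane[i] * m.lane[idx];
}

/* Matrix-by-vector product for a column-major 4x4 matrix held as four
 * column registers: one FMUL followed by three FMLAs. */
static vec4 mat_vec_mul(const vec4 col[4], vec4 v)
{
    vec4 r;
    for (int i = 0; i < 4; i++)       /* FMUL r, col0, v[0] */
        r.lane[i] = col[0].lane[i] * v.lane[0];
    for (int k = 1; k < 4; k++)       /* FMLA r, colk, v[k], for k = 1..3 */
        fmla_elem(&r, col[k], v, k);
    return r;
}
```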
If we only need to calculate a matrix-by-vector multiplication, the operation is now complete.
However, to complete the matrix-by-matrix multiplication, we must execute three more iterations, one for each remaining column of the result. These iterations use values y4 to yF in registers V5 to V7.
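The same column-major layout can be written as an equation. Here $\mathbf{a}_k$ is column $k$ of the first matrix, held in register V$k$. The element $y_{4j+k}$ is lane $k$ of register V$(4+j)$, with indices in hexadecimal so that $y_F = y_{15}$. Under these conventions, each result column $\mathbf{r}_j$, held in V$(8+j)$, is a weighted sum of the first matrix's columns:

```latex
\mathbf{r}_j \;=\; \sum_{k=0}^{3} y_{4j+k}\,\mathbf{a}_k,
\qquad j = 0, 1, 2, 3
```

For example, the second result column ($j = 1$, register V9) is $\mathbf{r}_1 = y_4\mathbf{a}_0 + y_5\mathbf{a}_1 + y_6\mathbf{a}_2 + y_7\mathbf{a}_3$. This is one FMUL by V5.S[0] followed by FMLAs by V5.S[1], V5.S[2], and V5.S[3].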
The following code shows the full implementation of a four-by-four floating-point matrix multiply:
matrix_mul_float: // X0 = result, X1 = matrix 0, X2 = matrix 1
LD1 {V0.4S, V1.4S, V2.4S, V3.4S}, [X1] // load all 16 elements of matrix 0 into
// V0-V3, four elements per register
LD1 {V4.4S, V5.4S, V6.4S, V7.4S}, [X2] // load all 16 elements of matrix 1 into
// V4-V7, four elements per register
FMUL V8.4S, V0.4S, V4.S[0] // rslt col0 = (mat0 col0) * (mat1 col0 elt0)
FMUL V9.4S, V0.4S, V5.S[0] // rslt col1 = (mat0 col0) * (mat1 col1 elt0)
FMUL V10.4S, V0.4S, V6.S[0] // rslt col2 = (mat0 col0) * (mat1 col2 elt0)
FMUL V11.4S, V0.4S, V7.S[0] // rslt col3 = (mat0 col0) * (mat1 col3 elt0)
FMLA V8.4S, V1.4S, V4.S[1] // rslt col0 += (mat0 col1) * (mat1 col0 elt1)
FMLA V9.4S, V1.4S, V5.S[1] // rslt col1 += (mat0 col1) * (mat1 col1 elt1)
FMLA V10.4S, V1.4S, V6.S[1] // rslt col2 += (mat0 col1) * (mat1 col2 elt1)
FMLA V11.4S, V1.4S, V7.S[1] // rslt col3 += (mat0 col1) * (mat1 col3 elt1)
FMLA V8.4S, V2.4S, V4.S[2] // rslt col0 += (mat0 col2) * (mat1 col0 elt2)
FMLA V9.4S, V2.4S, V5.S[2] // rslt col1 += (mat0 col2) * (mat1 col1 elt2)
FMLA V10.4S, V2.4S, V6.S[2] // rslt col2 += (mat0 col2) * (mat1 col2 elt2)
FMLA V11.4S, V2.4S, V7.S[2] // rslt col3 += (mat0 col2) * (mat1 col3 elt2)
FMLA V8.4S, V3.4S, V4.S[3] // rslt col0 += (mat0 col3) * (mat1 col0 elt3)
FMLA V9.4S, V3.4S, V5.S[3] // rslt col1 += (mat0 col3) * (mat1 col1 elt3)
FMLA V10.4S, V3.4S, V6.S[3] // rslt col2 += (mat0 col3) * (mat1 col2 elt3)
FMLA V11.4S, V3.4S, V7.S[3] // rslt col3 += (mat0 col3) * (mat1 col3 elt3)
ST1 {V8.4S, V9.4S, V10.4S, V11.4S}, [X0] // store all 16 elements of result
RET // return to caller
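To check the listing's arithmetic on any host, here is a plain-C reference sketch that follows the same instruction schedule. It is compared against a textbook triple loop. Both function names are hypothetical. The parameter order mirrors the registers the assembly reads (X0 = result, X1 = matrix 0, X2 = matrix 1). Under AAPCS64, the first three pointer arguments arrive in X0, X1, and X2, so a C prototype such as `void matrix_mul_float(float *result, const float *mat0, const float *mat1);` would be a plausible way to declare the assembly routine. That prototype is our assumption and is not part of the original listing.

```c
#include <assert.h>

/* Plain-C reference for matrix_mul_float. Assumptions: 4x4 float matrices,
 * column-major, 16 consecutive elements each; parameter order mirrors the
 * registers used by the assembly (X0 = result, X1 = mat0, X2 = mat1). */
static void matrix_mul_float_ref(float *result, const float *mat0,
                                 const float *mat1)
{
    assert(result && mat0 && mat1);
    float acc[4][4];                      /* acc[j] models V(8+j) */
    /* Same schedule as the listing: for each column k of mat0 (V0..V3),
     * update all four result columns before moving to the next k. */
    for (int k = 0; k < 4; k++)
        for (int j = 0; j < 4; j++)
            for (int i = 0; i < 4; i++) {
                float prod = mat0[4 * k + i] * mat1[4 * j + k];
                acc[j][i] = (k == 0) ? prod : acc[j][i] + prod; /* FMUL, then FMLA */
            }
    for (int j = 0; j < 4; j++)           /* ST1: write result columns back */
        for (int i = 0; i < 4; i++)
            result[4 * j + i] = acc[j][i];
}

/* Textbook triple loop, C(i,j) = sum_k A(i,k) * B(k,j), column-major. */
static void matrix_mul_naive(float *c, const float *a, const float *b)
{
    for (int j = 0; j < 4; j++)
        for (int i = 0; i < 4; i++) {
            float sum = 0.0f;
            for (int k = 0; k < 4; k++)
                sum += a[4 * k + i] * b[4 * j + k];
            c[4 * j + i] = sum;
        }
}
```

The k-outer loop order mirrors the listing. All four FMULs come first, then the FMLAs grouped by column of matrix 0. A common reason for this interleaving is that consecutive instructions then write different accumulators and do not wait on one another. Integer-valued test inputs keep every intermediate exact, so the two routines can be compared with `==`. With general inputs, expect last-bit differences from the assembly, because FMLA rounds only once.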