It appears that in order to multiply two matrices we need to multiply each element in a row of one matrix by the corresponding element in a column of the other matrix (first by first, second by second, and so on). Once we are done, we sum the products. This is called a dot product. So, let’s start with that.
function getDotProduct(row::Vec{Int}, col::Vec{Int})
    @assert length(row) == length(col) "row & col must be of equal length"
    return map(*, row, col) |> sum
end
Note. Thanks to the type synonyms defined previously (Section 1), we saved some typing and used `Vec{Int}` instead of `Vector{Int}`. We will use such small conveniences throughout the book. The type synonyms are defined in Section 1 and in the code snippets for each chapter.
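The synonym definitions from Section 1 are not reproduced in this chapter. The sketch below assumes they are plain aliases created with `const`. The helper `firstElement` is hypothetical and only shows the alias being used in a type annotation.

```julia
# Hypothetical sketch: Section 1 holds the real definitions, which may differ.
const Vec = Vector   # assumed alias, so Vec{Int} names the same type as Vector{Int}

# The alias works anywhere the full name would, e.g. in a type annotation
# (firstElement is a hypothetical helper used only for illustration).
firstElement(v::Vec{Int})::Int = v[1]

println(Vec{Int} === Vector{Int})   # prints true
println(firstElement([7, 8, 9]))    # prints 7
```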
First, we check a simple assumption (both vectors have the same length) with `@assert`. Then we multiply each element of `row` by the corresponding element of `col` with `map`. `map` applies a function (its first argument) to every element of a collection (its second argument), like so:
# adds 10 to each element of a vector
map(x -> x + 10, [1, 2, 3])
[11, 12, 13]
Here we used a vector (`[1, 2, 3]`) and applied an anonymous function (`x -> x + 10`) to each of its elements. The function accepts one argument (`x`, to the left of the arrow `->`), adds 10 to it (`x + 10`, to the right of the arrow) and returns that value. Since `x` becomes every element of the vector `[1, 2, 3]` in turn, 10 is added to each component of the vector. The results are collected into a new vector, and the old vector is not changed. Interestingly, we may also use a function that accepts two arguments and apply it to corresponding elements of two vectors, like so:
map((x, y) -> x * y, [1, 2, 3], [10, 100, 1000])
[10, 200, 3000]
Here, `x` becomes every value of the `[1, 2, 3]` vector and `y` every value of the `[10, 100, 1000]` vector. Given that `x * y` is just syntactic sugar for the function call `*(x, y)`, we may simply pass `*` alone.
map(*, [1, 2, 3], [10, 100, 1000])
[10, 200, 3000]
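Here is a small sketch that makes those two ideas more concrete. It shows that operators are ordinary functions, and that an anonymous function can be called directly or replaced by a named one. `add10` is a hypothetical helper, not something defined in the book.

```julia
# Operators are ordinary functions: infix `2 * 3` is sugar for the call `*(2, 3)`.
println(*(2, 3) == 2 * 3)                      # prints true

# An anonymous function can be called directly...
println((x -> x + 10)(5))                      # prints 15

# ...or replaced by a named one (add10 is a hypothetical helper).
add10(x) = x + 10
println(map(add10, [1, 2, 3]))                 # prints [11, 12, 13]

# With two collections, map pairs up corresponding elements.
println(map(+, [1, 2, 3], [10, 100, 1000]))    # prints [11, 102, 1003]
```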
Since we are calculating a dot product, as an alternative (and to live up to its name) we could also use Julia’s dot (broadcasting) operator syntax in our `getDotProduct` function, like so:
[1, 2, 3] .* [10, 100, 1000]
[10, 200, 3000]
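Plugging that syntax into the function itself might look like the sketch below. `getDotProductBroadcast` is a hypothetical name chosen so it does not clash with the original. `Vector{Int}` replaces the `Vec{Int}` synonym so the snippet runs on its own.

```julia
# Hypothetical variant of getDotProduct that uses broadcasting instead of map.
function getDotProductBroadcast(row::Vector{Int}, col::Vector{Int})::Int
    @assert length(row) == length(col) "row & col must be of equal length"
    # `.*` multiplies corresponding elements; `sum` then adds the products
    return sum(row .* col)
end

println(getDotProductBroadcast([1, 2, 3], [10, 100, 1000]))   # prints 3210
```

The `@assert` still earns its keep here. Broadcasting would silently stretch a length-1 vector (`[2] .* [1, 2, 3]` gives `[2, 4, 6]`), so without the check a mismatched input could return a number instead of raising an error.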
Anyway, once we have the vector of products, we send it (with the pipe operator `|>`) as an input to `sum`.
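As a quick sanity check, the sketch below uses `Vector{Int}` in place of the `Vec{Int}` synonym. It works out one dot product by hand, shows that `x |> f` is the same as `f(x)`, and compares the result with `dot` from Julia’s `LinearAlgebra` standard library.

```julia
using LinearAlgebra  # standard library; provides dot

# Same definition as in the text, with Vector{Int} instead of the Vec{Int} synonym.
function getDotProduct(row::Vector{Int}, col::Vector{Int})
    @assert length(row) == length(col) "row & col must be of equal length"
    return map(*, row, col) |> sum
end

products = map(*, [1, 2, 3], [4, 5, 6])     # [4, 10, 18]

# `x |> f` is just another way of writing `f(x)`
println((products |> sum) == sum(products))  # prints true

# 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
println(getDotProduct([1, 2, 3], [4, 5, 6]))   # prints 32
println(dot([1, 2, 3], [4, 5, 6]))             # prints 32
```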
OK, time for multiplication itself.
function multiply(m1::Matrix{Int}, m2::Matrix{Int})::Matrix{Int}
    nRowsMat1, nColsMat1 = size(m1)
    nRowsMat2, nColsMat2 = size(m2)
    @assert nColsMat1 == nRowsMat2 "the matrices are incompatible"
    result::Matrix{Int} = zeros(Int, nRowsMat1, nColsMat2)
    for r in 1:nRowsMat1
        for c in 1:nColsMat2
            result[r, c] = getDotProduct(m1[r, :], m2[:, c])
        end
    end
    return result
end
The above is a translation of the algorithm from the links provided in the task description earlier on. First, we get our matrices’ dimensions and perform a compatibility check with `@assert`: the number of columns in `m1` must equal the number of rows in `m2`. Then we initialize a matrix (`result`) with the appropriate dimensions, that is, as many rows as `m1` and as many columns as `m2`. We use `zeros`, so 0s are the placeholders stored in its cells. Finally, we get the dot products of every row (`for r`) in `m1` by every column (`for c`) in `m2` and place them in the appropriate cells of the `result` matrix.
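In symbols, the nested loops compute the standard formula below. Here `m1` is $m \times n$ and `m2` is $n \times p$, so $m$ = `nRowsMat1`, $n$ = `nColsMat1` = `nRowsMat2` and $p$ = `nColsMat2`. The index $k$ runs over the element pairs that `getDotProduct` multiplies. The second line works out the top-left cell for `a = [1 2 3; 4 5 6]` and `b = [7 8; 9 10; 11 12]`, the Math is Fun example used below.

$$
\text{result}_{rc} = \sum_{k=1}^{n} (\text{m1})_{rk} \, (\text{m2})_{kc}, \qquad r = 1, \dots, m, \quad c = 1, \dots, p
$$

$$
\text{result}_{11} = 1 \cdot 7 + 2 \cdot 9 + 3 \cdot 11 = 7 + 18 + 33 = 58
$$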
Alternatively, if you are not a fan of nesting, you may use Julia’s compact nested `for` loop syntax (`for r in ..., c in ...`). It works the same as the previous code snippet.
function multiply(m1::Matrix{Int}, m2::Matrix{Int})::Matrix{Int}
    nRowsMat1, nColsMat1 = size(m1)
    nRowsMat2, nColsMat2 = size(m2)
    @assert nColsMat1 == nRowsMat2 "the matrices are incompatible"
    result::Matrix{Int} = zeros(Int, nRowsMat1, nColsMat2)
    for r in 1:nRowsMat1, c in 1:nColsMat2
        result[r, c] = getDotProduct(m1[r, :], m2[:, c])
    end
    return result
end
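If you would rather skip the explicit loops and the preallocated `result`, a 2D array comprehension can express the same idea. This is a hypothetical alternative, not the approach taken in the text. `multiplyComprehension` is an illustrative name, and `Vector{Int}` stands in for the `Vec{Int}` synonym.

```julia
# Same getDotProduct as in the text, with Vector{Int} instead of Vec{Int}.
function getDotProduct(row::Vector{Int}, col::Vector{Int})
    @assert length(row) == length(col) "row & col must be of equal length"
    return map(*, row, col) |> sum
end

# Hypothetical alternative: build the result with a 2D array comprehension.
function multiplyComprehension(m1::Matrix{Int}, m2::Matrix{Int})::Matrix{Int}
    nRowsMat1, nColsMat1 = size(m1)
    nRowsMat2, nColsMat2 = size(m2)
    @assert nColsMat1 == nRowsMat2 "the matrices are incompatible"
    # the first iterator (r) runs along rows, the second (c) along columns
    return [getDotProduct(m1[r, :], m2[:, c]) for r in 1:nRowsMat1, c in 1:nColsMat2]
end

println(multiplyComprehension([1 2 3; 4 5 6], [7 8; 9 10; 11 12]) == [58 64; 139 154])   # prints true
```

The comprehension builds the matrix directly from the dot products, so no `zeros` placeholder is needed.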
Anyway, let’s give it a swing.
# Math is Fun examples
a = [1 2 3; 4 5 6]
b = [7 8; 9 10; 11 12]
multiply(a, b)
2×2 Matrix{Int64}:
58 64
139 154
Looks good, and
# Khan Academy examples
c = [-1 3 5; 5 5 2]
d = [3 4; 3 -2; 4 -2]
multiply(c, d)
2×2 Matrix{Int64}:
26 -20
38 6
Appears to be correct as well.
And now for a few tests against the built-in `*` operator.
(a * b) == multiply(a, b)
true
(c * d) == multiply(c, d)
true
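Two hand-picked examples are a good start, but a few random ones cost little. The sketch below copies the text’s functions, with `Vector{Int}` used instead of `Vec{Int}`, and adds a hypothetical `randomCheck` helper. It compares `multiply` with the built-in `*` on random integer matrices of compatible sizes. It also confirms that incompatible shapes trip the `@assert`.

```julia
# Copies of the text's functions (Vector{Int} instead of the Vec{Int} synonym).
function getDotProduct(row::Vector{Int}, col::Vector{Int})
    @assert length(row) == length(col) "row & col must be of equal length"
    return map(*, row, col) |> sum
end

function multiply(m1::Matrix{Int}, m2::Matrix{Int})::Matrix{Int}
    nRowsMat1, nColsMat1 = size(m1)
    nRowsMat2, nColsMat2 = size(m2)
    @assert nColsMat1 == nRowsMat2 "the matrices are incompatible"
    result::Matrix{Int} = zeros(Int, nRowsMat1, nColsMat2)
    for r in 1:nRowsMat1, c in 1:nColsMat2
        result[r, c] = getDotProduct(m1[r, :], m2[:, c])
    end
    return result
end

# Hypothetical helper: compare multiply with `*` on random compatible matrices.
function randomCheck(nTrials::Int)::Bool
    for _ in 1:nTrials
        m, n, p = rand(1:5, 3)     # random dimensions from 1 to 5
        m1 = rand(-10:10, m, n)    # m × n matrix of small integers
        m2 = rand(-10:10, n, p)    # n × p matrix, so the shapes are compatible
        if multiply(m1, m2) != m1 * m2
            return false
        end
    end
    return true
end

println(randomCheck(100))   # prints true

# Incompatible shapes (2×3 times 2×2) should trip the @assert.
try
    multiply([1 2 3; 4 5 6], [1 2; 3 4])
catch e
    println(e isa AssertionError)   # prints true
end
```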
We can’t complain. It appears that we managed to solve this task in roughly 15 lines of code without over-engineering it. It’s all thanks to Julia’s nice and terse syntax.