Decoding NumPy’s Dot Product: A Brief Exploration of Dimensional Wizardry | by Mario Namtao Shianti Larcher | Jul, 2023


Clarifying once and for all the confusion over NumPy’s dot product

Image generated with DreamStudio with the prompt “A chaotic, dark, gloomy, multidimensional world full of code wizards”.

Am I the only one who periodically gets confused when dealing with dimensions in NumPy? Today, while reading a Gradio documentation page, I came across the following code snippet:

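import numpy as np

sepia_filter = np.array([
    [0.393, 0.769, 0.189],
    [0.349, 0.686, 0.168],
    [0.272, 0.534, 0.131],
])
# input_img shape (H, W, 3); here a random stand-in for the image Gradio passes in
input_img = np.random.randint(0, 256, size=(64, 48, 3))
# sepia_filter shape (3, 3)
sepia_img = input_img.dot(sepia_filter.T)  # <- why is this legal??
sepia_img /= sepia_img.max()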

Hey, hey, hey! Why is the dot product of an image (H, W, 3) with a filter (3, 3) legal? I asked ChatGPT to explain it to me, but it either gave me wrong answers (like claiming this doesn't work) or ignored my question and answered something else instead. So there was nothing left but to use my brain (plus read the documentation, sigh).

If you are also a little confused by the code above, keep reading.

From the NumPy dot product documentation (with minor modifications):

If a.shape = (I, J, C) and b.shape = (K, C, L), then dot(a, b)[i, j, k, l] = sum(a[i, j, :] * b[k, :, l]). Notice that the last dimension of “a” must equal the second-to-last dimension of “b”: that shared axis of length C is the one being summed over, and the result has shape (I, J, K, L).

Or, in code:

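import numpy as np

I, J, K, L, C = 10, 20, 30, 40, 50
a = np.random.random((I, J, C))
b = np.random.random((K, C, L))
c = a.dot(b)  # shape (I, J, K, L)
i, j, k, l = 3, 2, 4, 5
print(c[i, j, k, l])                 # one element of the dot product
print(sum(a[i, j, :] * b[k, :, l]))  # the same element computed by hand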

Output (same result; the exact value changes on every run because no random seed is set):

13.125012901284713
13.125012901284713
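Checking a single element is convincing, but the same rule can be checked for every index at once. The sketch below writes the quoted formula as an `np.einsum` subscript string, `"ijc,kcl->ijkl"`, and compares it with `a.dot(b)` using `np.allclose`; the smaller sizes are arbitrary choices for this example. It then applies the rule to the sepia snippet: when `b` is 2-D, its second-to-last axis is axis 0, so each output pixel is simply the 3×3 filter applied to that pixel's RGB vector. The small random image is an assumed stand-in for `input_img`, not the one from the Gradio page.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same rule as above, with smaller (arbitrary) sizes: a is (I, J, C), b is (K, C, L)
I, J, K, L, C = 2, 3, 4, 5, 6
a = rng.random((I, J, C))
b = rng.random((K, C, L))

# dot(a, b)[i, j, k, l] = sum over c of a[i, j, c] * b[k, c, l]
c = a.dot(b)
c_einsum = np.einsum("ijc,kcl->ijkl", a, b)
print(c.shape)                   # (2, 3, 4, 5)
print(np.allclose(c, c_einsum))  # True

# Sepia case: b is 2-D, so the sum runs over the last axis of the image
# and axis 0 of sepia_filter.T (i.e. axis 1 of sepia_filter).
sepia_filter = np.array([
    [0.393, 0.769, 0.189],
    [0.349, 0.686, 0.168],
    [0.272, 0.534, 0.131],
])
input_img = rng.random((4, 5, 3))  # assumed stand-in for an (H, W, 3) image

sepia_img = input_img.dot(sepia_filter.T)
print(sepia_img.shape)  # (4, 5, 3)

# Pixel by pixel, each output RGB vector is sepia_filter @ pixel
h, w = 1, 2
print(np.allclose(sepia_img[h, w], sepia_filter @ input_img[h, w]))  # True
```

In other words, `input_img.dot(sepia_filter.T)` is legal because NumPy only needs the image's last axis (3 channels) to match the filter's first axis (3 rows of the transpose); the height and width axes are simply carried along.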

To work out the shape of a dot product's result in advance, follow these steps:

Step 1: Consider two arrays, “a” and “b,” with their respective shapes.

import numpy as np

# Example shapes for arrays a and b
a_shape = (4, 3, 2)
b_shape = (3, 2, 5)
# Create random arrays with the specified shapes
a = np.random.random(a_shape)
b =…
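With these example shapes, the result shape follows directly from the rule quoted earlier: drop the last axis of “a”, drop the second-to-last axis of “b”, and concatenate what remains, in order. The helper `predict_dot_shape` below is a hypothetical name used only for this sketch; it assumes “b” has at least two dimensions (the case covered by the rule above) and compares its prediction with the shape NumPy actually returns.

```python
import numpy as np


def predict_dot_shape(a_shape, b_shape):
    """Hypothetical helper: predict np.dot(a, b).shape when a.ndim >= 1 and b.ndim >= 2."""
    if len(b_shape) < 2:
        raise ValueError("this sketch only covers b with at least 2 dimensions")
    # The contracted axes: the last axis of a and the second-to-last axis of b
    if a_shape[-1] != b_shape[-2]:
        raise ValueError(f"shapes {a_shape} and {b_shape} not aligned")
    # Keep every other axis of a, then every other axis of b, in order
    return tuple(a_shape[:-1]) + tuple(b_shape[:-2]) + tuple(b_shape[-1:])


a_shape = (4, 3, 2)
b_shape = (3, 2, 5)
a = np.random.random(a_shape)
b = np.random.random(b_shape)

predicted = predict_dot_shape(a_shape, b_shape)
print(predicted)       # (4, 3, 3, 5)
print(a.dot(b).shape)  # (4, 3, 3, 5)

# The sepia snippet: (H, W, 3) dot (3, 3) -> (H, W, 3)
print(predict_dot_shape((64, 48, 3), (3, 3)))  # (64, 48, 3)
```

The same function explains the sepia case: with a 2-D “b” there is nothing left of `b_shape[:-2]`, so the output keeps the image's (H, W) and takes its last axis from the filter.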
