arrays - Efficient reduction of multiple tensors in Python


I have four multidimensional tensors v[i,j,k], a[i,s,l], w[j,s,t,m], x[k,t,n] in NumPy, and I am trying to compute the tensor z[l,m,n] given by:

z[l,m,n] = sum_{i,j,k,s,t} v[i,j,k] * a[i,s,l] * w[j,s,t,m] * x[k,t,n]

All the tensors are relatively small (say, fewer than 32k elements in total), but I need to perform this computation many times, so I would like the function to have as little overhead as possible.

I tried to implement it using numpy.einsum like this:

z = np.einsum('ijk,isl,jstm,ktn', v, a, w, x) 

But it was very slow. I also tried the following sequence of numpy.tensordot calls:

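z = np.zeros((a.shape[-1], w.shape[-1], x.shape[-1]))
for s in range(a.shape[1]):
  for t in range(x.shape[1]):
    res = np.tensordot(v, a[:, s, :], (0, 0))       # sum over i -> (j, k, l)
    res = np.tensordot(res, w[:, s, t, :], (0, 0))  # sum over j -> (k, l, m)
    z += np.tensordot(res, x[:, t, :], (0, 0))      # sum over k -> (l, m, n); x is indexed by t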

Inside the double loop I sum over s and t (both s and t are very small, so that is not too much of a problem). This worked better, but it is still not as fast as I would expect. I think it may be because of all the operations that tensordot needs to perform internally before taking the actual product (e.g. permuting the axes).
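To make the "internal operations" concrete, here is a rough sketch of what a tensordot over the first axis of each operand amounts to: flatten both operands to 2-D, transpose one, and do a single matrix product. The helper name tensordot_axis0 and the array sizes are made up for illustration, and this is not NumPy's actual source, only the general idea.

```python
import numpy as np

def tensordot_axis0(a, b):
    """Contract axis 0 of `a` with axis 0 of `b` using one 2-D matrix product."""
    # Flatten the non-contracted axes so each operand becomes a 2-D matrix.
    a2 = a.reshape(a.shape[0], -1)   # (i, product of the remaining axes of a)
    b2 = b.reshape(b.shape[0], -1)   # (i, product of the remaining axes of b)
    # The transpose is the "permute axes" step; the product is the real work.
    out = a2.T @ b2
    # Restore the remaining axes of a, followed by the remaining axes of b.
    return out.reshape(a.shape[1:] + b.shape[1:])

rng = np.random.default_rng(0)
v = rng.random((4, 5, 6))    # hypothetical v[i,j,k]
a_s = rng.random((4, 7))     # hypothetical slice a[:, s, :]
same = np.allclose(tensordot_axis0(v, a_s), np.tensordot(v, a_s, (0, 0)))
print(same)
```

For operands this small, the reshaping and Python-level bookkeeping around the product can plausibly cost as much as the product itself, which would be consistent with the loop version feeling slow.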

I was wondering if there is a more efficient way to implement this kind of operation in NumPy. I also wouldn't mind implementing this part in Cython, but I'm not sure what the right algorithm would be.

Using np.tensordot in parts, you can vectorize things like so -

# Perform "np.einsum('ijk,isl->jksl', v, a)"
p1 = np.tensordot(v, a, axes=([0], [0]))         # shape = jksl

# Perform "np.einsum('jksl,jstm->kltm', p1, w)"
p2 = np.tensordot(p1, w, axes=([0, 2], [0, 1]))  # shape = kltm

# Perform "np.einsum('kltm,ktn->lmn', p2, x)"
z = np.tensordot(p2, x, axes=([0, 2], [0, 1]))   # shape = lmn
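One way to see why the three-step contraction helps is to count multiply-adds. The sketch below uses illustrative sizes (10 for each of i, j, k, l, m, n and 3 for s, t; these are assumptions) and assumes that an unoptimized einsum evaluates the whole sum in a single nested loop. It ignores BLAS, caching and call overhead, so it is only a rough estimate.

```python
# Illustrative sizes (assumed): 10 for i, j, k, l, m, n and 3 for s, t.
i = j = k = l = m = n = 10
s = t = 3

# Single-pass contraction: every output element (l, m, n) sums over i, j, k, s, t.
naive = (l * m * n) * (i * j * k * s * t)

# Pairwise contraction, one term per tensordot call.
step1 = i * (j * k) * (s * l)        # v, a  -> jksl (sum over i)
step2 = (j * s) * (k * l) * (t * m)  # p1, w -> kltm (sum over j and s)
step3 = (k * t) * (l * m) * n        # p2, x -> lmn  (sum over k and t)
pairwise = step1 + step2 + step3

print(naive, pairwise, naive / pairwise)
```

With these sizes the single-pass count is 9,000,000 and the pairwise count is 150,000, a factor of 60. Each pairwise step is also a plain matrix product, so it can be handed to an optimized BLAS routine.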

Runtime test and output verification -

In [15]: def einsum_based(v, a, w, x):
    ...:     return np.einsum('ijk,isl,jstm,ktn', v, a, w, x) # (l,m,n)
    ...:
    ...: def vectorized_tdot(v, a, w, x):
    ...:     p1 = np.tensordot(v,a,axes=([0],[0]))        # shape = jksl
    ...:     p2 = np.tensordot(p1,w,axes=([0,2],[0,1]))   # shape = kltm
    ...:     return np.tensordot(p2,x,axes=([0,2],[0,1])) # shape = lmn
    ...:

Case #1 :

In [16]: # Input params
    ...: i,j,k,l,m,n = 10,10,10,10,10,10
    ...: s,t = 3,3 # the problem states that both s and t are small
    ...:
    ...: # Input arrays
    ...: v = np.random.rand(i,j,k)
    ...: a = np.random.rand(i,s,l)
    ...: w = np.random.rand(j,s,t,m)
    ...: x = np.random.rand(k,t,n)
    ...:

In [17]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x))
Out[17]: True

In [18]: %timeit einsum_based(v,a,w,x)
10 loops, best of 3: 129 ms per loop

In [19]: %timeit vectorized_tdot(v,a,w,x)
1000 loops, best of 3: 397 µs per loop

Case #2 (bigger data sizes) :

In [20]: # Input params
    ...: i,j,k,l,m,n = 15,15,15,15,15,15
    ...: s,t = 3,3 # the problem states that both s and t are small
    ...:
    ...: # Input arrays
    ...: v = np.random.rand(i,j,k)
    ...: a = np.random.rand(i,s,l)
    ...: w = np.random.rand(j,s,t,m)
    ...: x = np.random.rand(k,t,n)
    ...:

In [21]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x))
Out[21]: True

In [22]: %timeit einsum_based(v,a,w,x)
1 loops, best of 3: 1.35 s per loop

In [23]: %timeit vectorized_tdot(v,a,w,x)
1000 loops, best of 3: 1.52 ms per loop
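As a possible alternative, assuming a NumPy version recent enough to support the optimize keyword of np.einsum, einsum can be asked to pick a pairwise contraction order itself, and np.einsum_path can compute that order once for reuse. The sizes below are arbitrary, and no timings are claimed here; whether this beats the hand-written tensordot chain for tensors this small would need to be measured.

```python
import numpy as np

rng = np.random.default_rng(0)
i, j, k, l, m, n, s, t = 6, 6, 6, 5, 5, 5, 3, 3   # arbitrary small sizes
v = rng.random((i, j, k))
a = rng.random((i, s, l))
w = rng.random((j, s, t, m))
x = rng.random((k, t, n))

subscripts = 'ijk,isl,jstm,ktn->lmn'

# Ask einsum to choose a pairwise contraction order itself.
z_opt = np.einsum(subscripts, v, a, w, x, optimize=True)

# When the shapes never change, the path can be computed once and reused.
path, report = np.einsum_path(subscripts, v, a, w, x, optimize='optimal')
z_path = np.einsum(subscripts, v, a, w, x, optimize=path)

# Reference: the three-step tensordot contraction.
p1 = np.tensordot(v, a, axes=([0], [0]))          # shape = jksl
p2 = np.tensordot(p1, w, axes=([0, 2], [0, 1]))   # shape = kltm
z_ref = np.tensordot(p2, x, axes=([0, 2], [0, 1]))  # shape = lmn

print(np.allclose(z_opt, z_ref), np.allclose(z_path, z_ref))
```

Printing report shows the contraction order that einsum_path selected, which can be compared against the order used in the tensordot version.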
