Merge remote-tracking branch 'origin/master'
This commit is contained in:
@@ -19,15 +19,58 @@ import numpy.linalg
|
|||||||
from numpy import dot, trace
|
from numpy import dot, trace
|
||||||
from numpy.linalg import det, inv
|
from numpy.linalg import det, inv
|
||||||
|
|
||||||
|
MATMUL_USE_BLAS = False
|
||||||
|
|
||||||
def matmul(*Mats):
|
def matmul(*Mats, **opts):
|
||||||
"""Do successive matrix product. For example,
|
"""Do successive matrix product. For example,
|
||||||
matmul(A,B,C,D)
|
matmul(A,B,C,D)
|
||||||
will evaluate a matrix multiplication ((A*B)*C)*D .
|
will evaluate a matrix multiplication ((A*B)*C)*D .
|
||||||
The matrices must be of matching sizes."""
|
The matrices must be of matching sizes."""
|
||||||
p = numpy.dot(Mats[0], Mats[1])
|
from numpy import asarray, dot, iscomplexobj
|
||||||
for M in Mats[2:]:
|
use_blas = opts.get('use_blas', MATMUL_USE_BLAS)
|
||||||
p = numpy.dot(p, M)
|
debug = opts.get('debug', True)
|
||||||
|
if debug:
|
||||||
|
def dbg(msg):
|
||||||
|
print msg,
|
||||||
|
else:
|
||||||
|
def dbg(msg):
|
||||||
|
pass
|
||||||
|
if use_blas:
|
||||||
|
try:
|
||||||
|
from scipy.linalg.blas import zgemm, dgemm
|
||||||
|
except:
|
||||||
|
# Older scipy (<= 0.10?)
|
||||||
|
from scipy.linalg.blas import fblas
|
||||||
|
zgemm = fblas.zgemm
|
||||||
|
dgemm = fblas.dgemm
|
||||||
|
|
||||||
|
if not use_blas:
|
||||||
|
p = dot(Mats[0], Mats[1])
|
||||||
|
for M in Mats[2:]:
|
||||||
|
p = dot(p, M)
|
||||||
|
else:
|
||||||
|
dbg("Using BLAS\n")
|
||||||
|
# FIXME: Right now only supporting double precision arithmetic.
|
||||||
|
M0 = asarray(Mats[0])
|
||||||
|
M1 = asarray(Mats[1])
|
||||||
|
if iscomplexobj(M0) or iscomplexobj(M1):
|
||||||
|
p = zgemm(alpha=1.0, a=M0, b=M1)
|
||||||
|
Cplx = True
|
||||||
|
dbg("- zgemm ")
|
||||||
|
else:
|
||||||
|
p = dgemm(alpha=1.0, a=M0, b=M1)
|
||||||
|
Cplx = False
|
||||||
|
dbg("- dgemm ")
|
||||||
|
for M in Mats[2:]:
|
||||||
|
M2 = asarray(M)
|
||||||
|
if Cplx or iscomplexobj(M2):
|
||||||
|
p = zgemm(alpha=1.0, a=p, b=M2)
|
||||||
|
Cplx = True
|
||||||
|
dbg(" zgemm")
|
||||||
|
else:
|
||||||
|
p = dgemm(alpha=1.0, a=p, b=M2)
|
||||||
|
dbg(" dgemm")
|
||||||
|
dbg("\n")
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user