python - theano.scan: Non-unit value on shape on a broadcastable dimension


I am developing a simple program that uses the theano.scan function to loop through an array of vectors (in the end, I intend to develop an LSTM layer with this program). The problem is that I get the error "Non-unit value on shape on a broadcastable dimension" when compiling the function. I believe the updates parameter is the cause, because the error occurs whenever I pass it when compiling the function. Here is my code:

    import theano
    import theano.tensor as T
    from utils import *
    import numpy as np

    class LSTM:
        def __init__(self, x, in_size, out_size):
            self.x = x
            self.in_size = in_size
            self.out_size = out_size
            self.w_x = init_weights((in_size, out_size), "w_x")

            # Step function: one row of x per step; pre_h is the previous output (unused here)
            def _active(x, pre_h):
                x = T.reshape(x, (1, in_size))
                pre_h = T.dot(x, self.w_x)
                return pre_h

            h, updates = theano.scan(_active, sequences=x,
                outputs_info=[T.alloc(floatX(0.), 1, out_size)])

            self.activation = h

    if __name__ == "__main__":
        x = T.matrix('x')
        in_size = 2
        out_size = 4
        lstm = LSTM(x, in_size, out_size)
        value = lstm.activation
        cost = T.mean(value)
        params = [lstm.w_x]

        # Plain gradient-descent updates with learning rate 0.1
        updates = []
        for p in params:
            gp = T.grad(cost, p)
            updates.append((p, p - 0.1 * gp))

        f = theano.function([x], outputs=cost, updates=updates)

        test = f(np.random.rand(10, in_size))
        print(test)
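To check what the graph above should compute once it compiles, here is a minimal plain-Python sketch (no Theano, no NumPy) of the same forward pass, cost, gradient and update. It assumes, as in the code above, that the step function ignores `pre_h` and that the learning rate is 0.1. The helper names `scan_forward`, `mean_cost`, `grad_mean_cost` and `sgd_step` are hypothetical. The sketch also makes the shapes explicit. Each step returns a `(1, out_size)` row, and scan stacks the steps along a new leading axis, giving `(n_steps, 1, out_size)`.

```python
import random


def scan_forward(x, w):
    """Reference for the scan above: each row of x is treated as a
    (1, in_size) matrix and multiplied by w, giving a (1, out_size)
    output per step; the steps are stacked to (n_steps, 1, out_size)."""
    h = []
    for row in x:
        out = [sum(row[i] * w[i][j] for i in range(len(row)))
               for j in range(len(w[0]))]
        h.append([out])  # shape (1, out_size) for this step
    return h


def mean_cost(h):
    """Mean over every entry of the stacked output, like T.mean(value)."""
    vals = [v for step in h for r in step for v in r]
    return sum(vals) / len(vals)


def grad_mean_cost(x, w):
    """Analytic gradient of mean_cost(scan_forward(x, w)) w.r.t. w.
    cost = 1/(n*m) * sum_t sum_j sum_i x[t][i] * w[i][j], so
    d cost / d w[i][j] = 1/(n*m) * sum_t x[t][i] (same for every column j)."""
    n, m = len(x), len(w[0])
    return [[sum(row[i] for row in x) / (n * m) for _ in range(m)]
            for i in range(len(w))]


def sgd_step(w, g, lr=0.1):
    """One gradient-descent update, mirroring p - 0.1 * gp."""
    return [[w[i][j] - lr * g[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


if __name__ == "__main__":
    random.seed(0)
    in_size, out_size, n_steps = 2, 4, 10
    x = [[random.random() for _ in range(in_size)] for _ in range(n_steps)]
    w = [[random.gauss(0, 1) * 0.1 for _ in range(out_size)]
         for _ in range(in_size)]
    h = scan_forward(x, w)
    print(len(h), len(h[0]), len(h[0][0]))  # 10 1 4
    print(mean_cost(h))
    w = sgd_step(w, grad_mean_cost(x, w))
```

Because the cost is linear in `w`, every column of the gradient is identical. That gives an easy sanity check on whatever `T.grad(cost, p)` returns once the compile error is resolved.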

In the code, I use functions loaded from utils.py, which looks like this:

    #pylint: skip-file
    import numpy as np
    import theano
    import theano.tensor as T

    # Cast to Theano's configured float type (e.g. float32 on GPU)
    def floatX(x):
        return np.asarray(x, dtype=theano.config.floatX)

    # Small random weights, scaled by 0.1
    def init_weights(shape, name):
        return theano.shared(floatX(np.random.randn(*shape) * 0.1), name)

    def init_gradws(shape, name):
        return theano.shared(floatX(np.zeros(shape)), name)

    def init_bias(size, name):
        return theano.shared(floatX(np.zeros((size,))), name)

I have been searching for a while but could not find a solution to this problem, and I cannot see the problem in my code. If I do not use theano.scan, the code runs fine.

Can you see the problem in my code? Do you have any advice on how to solve it?

Thanks in advance.