
Finding the maximum of a function using JAX


I have a function whose maximum I would like to find by optimizing over two of its variables using JAX.

My current code, which does not work, is:

import jax.numpy as jnp
import jax
import scipy.optimize
import numpy as np

def temp_func(x, y, z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp

def obj_func(xy, z):
    # unpack the flat parameter vector: x has shape (2,), y has shape (2, 2)
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))

grad_tmp = jax.grad(obj_func, argnums=0)  # gradient w.r.t. xy (both x and y)

xy = jnp.concatenate([np.random.rand(2), np.random.rand(2 * 2)])
z = jnp.array(np.random.rand(2, 2))
print(obj_func(xy, z))

result = scipy.optimize.minimize(obj_func,
                                 xy,
                                 args=(z,),
                                 method='L-BFGS-B',
                                 jac=grad_tmp
                                )

With this code, I get the error ValueError: failed in converting 7th argument `g' of _lbfgsb.setulb to C/Fortran array. Do you have any suggestions for resolving this?


Solution
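The error most likely comes from a dtype mismatch: JAX computes in float32 by default, while the Fortran L-BFGS-B routine behind scipy.optimize.minimize expects the gradient as a float64 array. A minimal sketch of a fix, assuming you want to keep SciPy's L-BFGS-B, wraps the JAX functions so SciPy only ever receives a Python float and a float64 NumPy array. The wrapper names fun_np and jac_np, the fixed starting point, the seeded z, and the box bounds are illustrative assumptions; the bounds are there only because the objective is otherwise unbounded below.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize


def temp_func(x, y, z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp


def obj_func(xy, z):
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))


grad_tmp = jax.grad(obj_func, argnums=0)


def fun_np(xy, z):
    # Hypothetical wrapper: hand SciPy a plain Python float, not a 0-d JAX array.
    return float(obj_func(xy, z))


def jac_np(xy, z):
    # Hypothetical wrapper: JAX returns float32 by default, but the Fortran
    # L-BFGS-B routine needs a float64 gradient, so cast before returning.
    return np.asarray(grad_tmp(xy, z), dtype=np.float64)


rng = np.random.default_rng(0)
xy0 = np.full(6, 0.5)                         # arbitrary starting point
z = rng.uniform(0.5, 1.0, size=(2, 2))        # arbitrary positive z

# Illustrative box bounds (an assumption, not part of the question) that give
# the problem a finite optimum; without them the objective is unbounded below.
bounds = [(0.0, 1.0)] * xy0.size

result = scipy.optimize.minimize(fun_np, xy0, args=(z,), method="L-BFGS-B",
                                 jac=jac_np, bounds=bounds)
print(result.x)
```

With these bounds every entry of result.x should be driven to the upper bound 1, since for positive y and z the objective decreases in every coordinate. Enabling 64-bit mode with jax.config.update("jax_enable_x64", True) before creating any arrays is another way to get a float64 gradient.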

  • Consider using JAX's version of scipy.optimize.minimize, jax.scipy.optimize.minimize, which computes and uses the gradient automatically:

    import jax.scipy.optimize
    result = jax.scipy.optimize.minimize(obj_func, xy, args=(z,), method='BFGS')
    

    That said, the results in either case are not going to be very meaningful: for the positive z used here, your objective function is unbounded below, decreasing linearly in x and cubically in y, so the minimizer will simply drift toward x, y → ∞.
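A quick numerical check of the unboundedness, reusing the question's definitions: along the ray where all six entries of xy equal t, the objective reduces to -(4t + 2t³·S) with S = Σ tanh(z_kj) > 0, so it should keep falling as t grows. The seeded z below is an arbitrary choice made only for reproducibility.

```python
import jax.numpy as jnp
import numpy as np


def temp_func(x, y, z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp


def obj_func(xy, z):
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))


z = jnp.array(np.random.default_rng(0).random((2, 2)))

# Evaluate the objective with every entry of xy set to t.
values = [float(obj_func(jnp.full(6, t), z)) for t in (1.0, 10.0, 100.0, 1000.0)]
print(values)  # strictly decreasing: no finite minimum along this ray
```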