Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions pymoo/gradient/grad_jax.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from functools import partial

import jax

jax.config.update("jax_enable_x64", True)
from jax import vjp
from jax import vmap
import numpy as np
from jax import vjp, vmap
from jax._src.api import _jacrev_unravel, _std_basis
from jax.tree_util import (tree_map)

from jax.tree_util import tree_map

import pymoo.gradient.toolbox as anp
import numpy as np


def jax_elementwise_value_and_grad(f, x):
out, pullback = vjp(f, x)
u = _std_basis(out)
jac, = vmap(pullback, in_axes=0)(u)
(jac,) = vmap(pullback, in_axes=0)(u)

grad = tree_map(partial(_jacrev_unravel, out), x, jac)

Expand All @@ -36,14 +35,13 @@ def jax_vectorized_value_and_grad(f, x):

n, m = v.shape
a = np.zeros((ncols, n, m))
for i in range(m):
cols[k].append(cnt)
a[cnt, :, i] = 1.0
cnt += 1
cols[k].extend(range(cnt, cnt + m))
a[cnt : cnt + m, :, :] = np.eye(m)[:, np.newaxis, :]
cnt += m

u[k] = anp.array(a)

jac, = vmap(pullback, in_axes=0)(u)
(jac,) = vmap(pullback, in_axes=0)(u)
jac = np.array(jac)

grad = {k: np.swapaxes(jac[I], 0, 1) for k, I in cols.items()}
Expand Down
13 changes: 5 additions & 8 deletions pymoo/operators/crossover/pcx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@ def pcx(X, eta, zeta, index, random_state=None):
dist_to_centroid = np.maximum(eps, dist_to_centroid)

# orthogonal directions are computed
orth_dir = np.zeros_like(dist_to_index)

for i in range(n_parents):
if i != index:
temp1 = (diff_to_index[i] * diff_to_centroid).sum(axis=-1)
temp2 = temp1 / (dist_to_index[i] * dist_to_centroid)
temp3 = np.maximum(0.0, 1.0 - temp2 ** 2)
orth_dir[i] = dist_to_index[i] * (temp3 ** 0.5)
temp1 = (diff_to_index * diff_to_centroid).sum(axis=-1)
temp2 = temp1 / (dist_to_index * dist_to_centroid)
temp3 = np.maximum(0.0, 1.0 - temp2**2)
orth_dir = dist_to_index * (temp3**0.5)
orth_dir[index] = 0.0

# this is the avg of the perpendicular distances from other parents to the parent k
D_not = orth_dir.sum(axis=0) / (n_parents - 1)
Expand Down
11 changes: 3 additions & 8 deletions pymoo/operators/sampling/rnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,11 @@ def _do(self, problem, n_samples, *args, random_state=None, **kwargs):


class IntegerRandomSampling(FloatRandomSampling):

def _do(self, problem, n_samples, *args, random_state=None, **kwargs):
n, (xl, xu) = problem.n_var, problem.bounds()
return np.column_stack([random_state.integers(xl[k], xu[k] + 1, size=n_samples) for k in range(n)])
xl, xu = problem.bounds()
return random_state.integers(xl, xu + 1, size=(n_samples, problem.n_var))


class PermutationRandomSampling(Sampling):

def _do(self, problem, n_samples, *args, random_state=None, **kwargs):
X = np.full((n_samples, problem.n_var), 0, dtype=int)
for i in range(n_samples):
X[i, :] = random_state.permutation(problem.n_var)
return X
return np.argsort(random_state.random((n_samples, problem.n_var)), axis=1)
15 changes: 3 additions & 12 deletions pymoo/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@ def swap(M, a, b):

# repairs a numpy array to be in bounds
def repair(X, xl, xu):
larger_than_xu = X[0, :] > xu
X[0, larger_than_xu] = xu[larger_than_xu]

smaller_than_xl = X[0, :] < xl
X[0, smaller_than_xl] = xl[smaller_than_xl]

X[0] = np.clip(X[0], xl, xu)
return X


Expand All @@ -58,12 +53,8 @@ def parameter_less_constraints(F, CV, F_max=None):

@default_random_state
def random_permutations(n, l, concat=True, random_state=None):
P = []
for i in range(n):
P.append(random_state.permutation(l))
if concat:
P = np.concatenate(P)
return P
P = np.argsort(random_state.random((n, l)), axis=1)
return P.reshape(-1) if concat else P


def get_duplicates(M):
Expand Down