Commit

Format optimizers

jaeyoo committed Mar 15, 2023
1 parent 30066f0 commit 39a811d
Showing 2 changed files with 57 additions and 73 deletions.
42 changes: 17 additions & 25 deletions tensorflow_quantum/python/optimizers/rotosolve_minimizer.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The rotosolve minimization algorithm"""
import collections
"""The rotosolve minimization algorithm."""
import numpy as np
import tensorflow as tf

@@ -46,6 +45,7 @@ def prefer_static_value(x):


class RotosolveOptimizerResults(tf.experimental.ExtensionType):
"""ExtentionType of Rotosolve Optimizer tf.while_loop() inner state."""
converged: tf.Tensor
# Scalar boolean tensor indicating whether the minimum
# was found within tolerance.
@@ -60,7 +60,7 @@ class RotosolveOptimizerResults(tf.experimental.ExtensionType):
# this value is the argmin of the objective function.
# A tensor containing the value of the objective from
# previous iteration
objective_value_previous_iteration: tf.Tensor
objective_value_prev: tf.Tensor
# Save the evaluated value of the objective function
# from the previous iteration
objective_value: tf.Tensor
@@ -78,23 +78,16 @@ class RotosolveOptimizerResults(tf.experimental.ExtensionType):
# modifying. Reserved for internal use.

def to_dict(self):
"""Transforms immutable data to mutable dictionary."""
return {
"converged":
self.converged,
"num_iterations":
self.num_iterations,
"num_objective_evaluations":
self.num_objective_evaluations,
"position":
self.position,
"objective_value":
self.objective_value,
"objective_value_previous_iteration":
self.objective_value_previous_iteration,
"tolerance":
self.tolerance,
"solve_param_i":
self.solve_param_i,
"converged": self.converged,
"num_iterations": self.num_iterations,
"num_objective_evaluations": self.num_objective_evaluations,
"position": self.position,
"objective_value": self.objective_value,
"objective_value_prev": self.objective_value_prev,
"tolerance": self.tolerance,
"solve_param_i": self.solve_param_i,
}
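
Note: tf.experimental.ExtensionType instances are immutable, so the loop state cannot be modified in place inside tf.while_loop(). The to_dict() helper above supports the copy-update-rebuild pattern used throughout both minimizers; a minimal sketch of that pattern (illustrative only, not part of this diff):

    params = state.to_dict()                      # copy the immutable state into a mutable dict
    params.update({"num_iterations": state.num_iterations + 1})
    state = RotosolveOptimizerResults(**params)   # rebuild a fresh, immutable state object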


@@ -106,7 +99,7 @@ def _get_initial_state(initial_position, tolerance, expectation_value_function):
"num_objective_evaluations": tf.Variable(0),
"position": tf.Variable(initial_position),
"objective_value": expectation_value_function(initial_position),
"objective_value_previous_iteration": tf.Variable(0.),
"objective_value_prev": tf.Variable(0.),
"tolerance": tolerance,
"solve_param_i": tf.Variable(0),
}
@@ -214,7 +207,7 @@ def _rotosolve_one_parameter_once(state):
next_state_params.update({
"solve_param_i": state.solve_param_i + 1,
"position": new_position,
"objective_value_previous_iteration": state.objective_value,
"objective_value_prev": state.objective_value,
"objective_value": (expectation_value_function(new_position)),
})
return [RotosolveOptimizerResults(**next_state_params)]
@@ -265,10 +258,9 @@ def _body(state):
post_state = _rotosolve_all_parameters_once(pre_state)[0]
next_state_params = post_state.to_dict()
next_state_params.update({
"converged":
(tf.abs(post_state.objective_value -
post_state.objective_value_previous_iteration) <
post_state.tolerance),
"converged": (tf.abs(post_state.objective_value -
post_state.objective_value_prev) <
post_state.tolerance),
"num_iterations": post_state.num_iterations + 1,
})
return [RotosolveOptimizerResults(**next_state_params)]
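
For orientation, here is a minimal usage sketch of the Rotosolve minimizer after this change. It is hypothetical and not part of the commit: it assumes the module's public minimize(expectation_value_function, initial_position, ...) entry point (only the private helpers appear in this diff) and uses an objective that is sinusoidal in each parameter, which is the structure Rotosolve relies on.

    import numpy as np
    import tensorflow as tf
    from tensorflow_quantum.python.optimizers import rotosolve_minimizer

    def expectation(position):
        # Each parameter enters the objective as a sinusoid, so Rotosolve can
        # solve for its per-parameter optimum in closed form, one at a time.
        return tf.reduce_sum(tf.sin(position))

    initial_position = tf.random.uniform([4], minval=-np.pi, maxval=np.pi)
    result = rotosolve_minimizer.minimize(expectation, initial_position)
    # The returned state mirrors RotosolveOptimizerResults above
    # (position, converged, objective_value, ...).
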
88 changes: 40 additions & 48 deletions tensorflow_quantum/python/optimizers/spsa_minimizer.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The SPSA minimization algorithm"""
import collections
"""The SPSA minimization algorithm."""
import tensorflow as tf
import numpy as np

@@ -46,6 +45,7 @@ def prefer_static_value(x):


class SPSAOptimizerResults(tf.experimental.ExtensionType):
"""ExtentionType of SPSA Optimizer tf.while_loop() inner state."""
converged: tf.Tensor
# Scalar boolean tensor indicating whether the minimum
# was found within tolerance.
@@ -60,7 +60,7 @@ class SPSAOptimizerResults(tf.experimental.ExtensionType):
# this value is the argmin of the objective function.
# A tensor containing the value of the objective from
# previous iteration
objective_value_previous_iteration: tf.Tensor
objective_value_prev: tf.Tensor
# Save the evaluated value of the objective function
# from the previous iteration
objective_value: tf.Tensor
@@ -72,7 +72,7 @@ class SPSAOptimizerResults(tf.experimental.ExtensionType):
# Define the stop criteria. Iteration will stop when the
# objective value difference between two iterations is
# smaller than tolerance
lr: tf.Tensor
learning_rate: tf.Tensor
# Specifies the learning rate
alpha: tf.Tensor
# Specifies scaling of the learning rate
@@ -89,38 +89,27 @@ class SPSAOptimizerResults(tf.experimental.ExtensionType):
# (only applies if blocking is true).

def to_dict(self):
"""Transforms immutable data to mutable dictionary."""
return {
"converged":
self.converged,
"num_iterations":
self.num_iterations,
"num_objective_evaluations":
self.num_objective_evaluations,
"position":
self.position,
"objective_value":
self.objective_value,
"objective_value_previous_iteration":
self.objective_value_previous_iteration,
"tolerance":
self.tolerance,
"lr":
self.lr,
"alpha":
self.alpha,
"perturb":
self.perturb,
"gamma":
self.gamma,
"blocking":
self.blocking,
"allowed_increase":
self.allowed_increase,
"converged": self.converged,
"num_iterations": self.num_iterations,
"num_objective_evaluations": self.num_objective_evaluations,
"position": self.position,
"objective_value": self.objective_value,
"objective_value_prev": self.objective_value_prev,
"tolerance": self.tolerance,
"learning_rate": self.learning_rate,
"alpha": self.alpha,
"perturb": self.perturb,
"gamma": self.gamma,
"blocking": self.blocking,
"allowed_increase": self.allowed_increase,
}


def _get_initial_state(initial_position, tolerance, expectation_value_function,
lr, alpha, perturb, gamma, blocking, allowed_increase):
learning_rate, alpha, perturb, gamma, blocking,
allowed_increase):
"""Create SPSAOptimizerResults with initial state of search."""
init_args = {
"converged": tf.Variable(False),
@@ -129,9 +118,9 @@ def _get_initial_state(initial_position, tolerance, expectation_value_function,
"position": tf.Variable(initial_position),
"objective_value":
(tf.cast(expectation_value_function(initial_position), tf.float32)),
"objective_value_previous_iteration": tf.Variable(np.inf),
"objective_value_prev": tf.Variable(np.inf),
"tolerance": tolerance,
"lr": tf.Variable(lr),
"learning_rate": tf.Variable(learning_rate),
"alpha": tf.Variable(alpha),
"perturb": tf.Variable(perturb),
"gamma": tf.Variable(gamma),
@@ -146,7 +135,7 @@ def minimize(expectation_value_function,
tolerance=1e-5,
max_iterations=200,
alpha=0.602,
lr=1.0,
learning_rate=1.0,
perturb=1.0,
gamma=0.101,
blocking=False,
@@ -188,7 +177,8 @@ def minimize(expectation_value_function,
tolerance: Scalar `tf.Tensor` of real dtype. Specifies the tolerance
for the procedure. If the supremum norm between two iteration
vector is below this number, the algorithm is stopped.
lr: Scalar `tf.Tensor` of real dtype. Specifies the learning rate
learning_rate: Scalar `tf.Tensor` of real dtype.
Specifies the learning rate.
alpha: Scalar `tf.Tensor` of real dtype. Specifies scaling of the
learning rate.
perturb: Scalar `tf.Tensor` of real dtype. Specifies the size of the
@@ -227,7 +217,9 @@ def minimize(expectation_value_function,
max_iterations = tf.convert_to_tensor(max_iterations,
name='max_iterations')

lr_init = tf.convert_to_tensor(lr, name='initial_a', dtype='float32')
learning_rate_init = tf.convert_to_tensor(learning_rate,
name='initial_a',
dtype='float32')
perturb_init = tf.convert_to_tensor(perturb,
name='initial_c',
dtype='float32')
@@ -253,7 +245,7 @@ def _spsa_once(state):
state.perturb * delta_shift)

gradient_estimate = (v_p - v_m) / (2 * state.perturb) * delta_shift
update = state.lr * gradient_estimate
update = state.learning_rate * gradient_estimate
next_state_params = state.to_dict()
next_state_params.update({
"num_objective_evaluations":
@@ -263,11 +255,11 @@
current_obj = tf.cast(expectation_value_function(state.position -
update),
dtype=tf.float32)
if state.objective_value_previous_iteration + \
if state.objective_value_prev + \
state.allowed_increase >= current_obj or not state.blocking:
next_state_params.update({
"position": state.position - update,
"objective_value_previous_iteration": state.objective_value,
"objective_value_prev": state.objective_value,
"objective_value": current_obj
})
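
The estimator above perturbs every parameter simultaneously along a random ±1 direction and takes a single symmetric difference; because the entries of delta_shift are ±1, multiplying by delta_shift is equivalent to the element-wise division in the textbook SPSA formula. A short NumPy sketch of the same idea (names here are illustrative, not taken from this file):

    import numpy as np

    def spsa_gradient(f, theta, c):
        # Rademacher perturbation: each coordinate is +1 or -1 with equal probability.
        delta = np.random.choice([-1.0, 1.0], size=theta.shape)
        v_p = f(theta + c * delta)
        v_m = f(theta - c * delta)
        # Matches the (v_p - v_m) / (2 * perturb) * delta_shift line above.
        return (v_p - v_m) / (2.0 * c) * delta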

@@ -285,35 +277,35 @@ def _cond(state):

def _body(state):
"""Main optimization loop."""
new_lr = lr_init / (
new_learning_rate = learning_rate_init / (
(tf.cast(state.num_iterations + 1, tf.float32) +
0.01 * tf.cast(max_iterations, tf.float32))**state.alpha)
new_perturb = perturb_init / (tf.cast(state.num_iterations + 1,
tf.float32)**state.gamma)

pre_state_params = state.to_dict()
pre_state_params.update({
"lr": new_lr,
"learning_rate": new_learning_rate,
"perturb": new_perturb,
})

post_state = _spsa_once(SPSAOptimizerResults(**pre_state_params))[0]
post_state_params = post_state.to_dict()
tf.print("asdf", state.objective_value.dtype,
state.objective_value_previous_iteration.dtype)
state.objective_value_prev.dtype)
post_state_params.update({
"num_iterations":
post_state.num_iterations + 1,
"converged": (tf.abs(state.objective_value -
state.objective_value_previous_iteration) <
state.tolerance),
"converged":
(tf.abs(state.objective_value - state.objective_value_prev)
< state.tolerance),
})
return [SPSAOptimizerResults(**post_state_params)]

initial_state = _get_initial_state(initial_position, tolerance,
expectation_value_function, lr,
alpha, perturb, gamma, blocking,
allowed_increase)
expectation_value_function,
learning_rate, alpha, perturb, gamma,
blocking, allowed_increase)

return tf.while_loop(cond=_cond,
body=_body,
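
With the defaults shown above (learning_rate=1.0, alpha=0.602, perturb=1.0, gamma=0.101), the schedules in _body decay as learning_rate_init / (k + 1 + 0.01 * max_iterations)**alpha and perturb_init / (k + 1)**gamma, so the first step of a 200-iteration run uses a learning rate of roughly 1 / 3**0.602 ≈ 0.52 and a perturbation of 1.0. A minimal calling sketch using only the parameters visible in the signature above (hypothetical, not part of the commit):

    import tensorflow as tf
    from tensorflow_quantum.python.optimizers import spsa_minimizer

    def loss(position):
        # Simple smooth objective standing in for a circuit expectation value.
        return tf.reduce_sum(position ** 2)

    initial_position = tf.random.uniform([6], minval=-1.0, maxval=1.0)
    result = spsa_minimizer.minimize(loss, initial_position,
                                     max_iterations=100,
                                     learning_rate=0.5)
    # The returned state exposes the renamed fields defined above
    # (position, converged, objective_value_prev, learning_rate, ...).
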
