16 | 16 | from __future__ import annotations |
17 | 17 |
18 | 18 | import math |
19 | | -from typing import List, Union |
| 19 | +from typing import List, Union, Tuple |
20 | 20 |
21 | 21 | from .base_optimizer import BaseOptimizer |
22 | 22 |
@@ -161,45 +161,50 @@ def update( |
161 | 161 | bias_correction1 = 1 - self.beta1**self._time_step |
162 | 162 | bias_correction2 = 1 - self.beta2**self._time_step |
163 | 163 |
164 | | - def _adam_update_recursive(params, grads, first_moment, second_moment): |
| 164 | + def _adam_update_recursive( |
| 165 | + parameters: Union[float, List], |
| 166 | + gradients: Union[float, List], |
| 167 | + first_moment: Union[float, List], |
| 168 | + second_moment: Union[float, List] |
| 169 | + ) -> Tuple[Union[float, List], Union[float, List], Union[float, List]]: |
165 | 170 | # Handle scalar case |
166 | | - if isinstance(params, (int, float)): |
167 | | - if not isinstance(grads, (int, float)): |
| 171 | + if isinstance(parameters, (int, float)): |
| 172 | + if not isinstance(gradients, (int, float)): |
168 | 173 | raise ValueError( |
169 | 174 | "Shape mismatch: parameter is scalar but gradient is not" |
170 | 175 | ) |
171 | 176 |
172 | 177 | # Update first moment: m = β₁ * m + (1-β₁) * g |
173 | | - new_first_moment = self.beta1 * first_moment + (1 - self.beta1) * grads |
| 178 | + new_first_moment = self.beta1 * first_moment + (1 - self.beta1) * gradients |
174 | 179 |
175 | 180 | # Update second moment: v = β₂ * v + (1-β₂) * g² |
176 | 181 | new_second_moment = self.beta2 * second_moment + (1 - self.beta2) * ( |
177 | | - grads * grads |
| 182 | + gradients * gradients |
178 | 183 | ) |
179 | 184 |
180 | 185 | # Bias-corrected moments |
181 | 186 | m_hat = new_first_moment / bias_correction1 |
182 | 187 | v_hat = new_second_moment / bias_correction2 |
183 | 188 |
184 | 189 | # Parameter update: θ = θ - α * m̂ / (√v̂ + ε) |
185 | | - new_param = params - self.learning_rate * m_hat / ( |
| 190 | + new_param = parameters - self.learning_rate * m_hat / ( |
186 | 191 | math.sqrt(v_hat) + self.epsilon |
187 | 192 | ) |
188 | 193 |
189 | 194 | return new_param, new_first_moment, new_second_moment |
190 | 195 |
191 | 196 | # Handle list case |
192 | | - if len(params) != len(grads): |
| 197 | + if len(parameters) != len(gradients): |
193 | 198 | raise ValueError( |
194 | | - f"Shape mismatch: parameters length {len(params)} vs " |
195 | | - f"gradients length {len(grads)}" |
| 199 | + f"Shape mismatch: parameters length {len(parameters)} vs " |
| 200 | + f"gradients length {len(gradients)}" |
196 | 201 | ) |
197 | 202 |
198 | 203 | new_params = [] |
199 | 204 | new_first_moments = [] |
200 | 205 | new_second_moments = [] |
201 | 206 |
202 | | - for p, g, m1, m2 in zip(params, grads, first_moment, second_moment): |
| 207 | + for p, g, m1, m2 in zip(parameters, gradients, first_moment, second_moment): |
203 | 208 | if isinstance(p, list) and isinstance(g, list): |
204 | 209 | # Recursive case for nested lists |
205 | 210 | new_p, new_m1, new_m2 = _adam_update_recursive(p, g, m1, m2) |
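
The scalar branch in the hunk above is the textbook Adam rule, with `bias_correction1` and `bias_correction2` playing the role of the `(1 - β**t)` denominators. A minimal standalone sketch of one such step, assuming the common Adam defaults (`lr=0.001`, `beta1=0.9`, `beta2=0.999`, `eps=1e-8`) rather than whatever this class configures; `adam_step` is an illustrative helper, not part of the module:

```python
import math

def adam_step(param: float, grad: float, m: float, v: float, t: int,
              lr: float = 0.001, beta1: float = 0.9,
              beta2: float = 0.999, eps: float = 1e-8):
    """One Adam step on a scalar parameter; returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad           # first moment:  m = β₁m + (1-β₁)g
    v = beta2 * v + (1 - beta2) * grad * grad    # second moment: v = β₂v + (1-β₂)g²
    m_hat = m / (1 - beta1 ** t)                 # bias-corrected moments
    v_hat = v / (1 - beta2 ** t)
    return param - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

theta, m, v = adam_step(1.0, 2.0, 0.0, 0.0, t=1)
print(theta)  # ≈ 0.999
```

At `t=1` the bias correction makes `m_hat` equal the raw gradient and `v_hat` its square, so the first step has magnitude ≈ `lr` regardless of gradient scale, which is why the correction terms are computed once per `update` call before the recursion.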
@@ -309,11 +314,11 @@ def __str__(self) -> str: |
309 | 314 | x_adagrad = [-1.0, 1.0] |
310 | 315 | x_adam = [-1.0, 1.0] |
311 | 316 |
312 | | - def rosenbrock(x, y): |
| 317 | + def rosenbrock(x: float, y: float) -> float: |
313 | 318 | """Rosenbrock function: f(x,y) = 100*(y-x²)² + (1-x)²""" |
314 | 319 | return 100 * (y - x * x) ** 2 + (1 - x) ** 2 |
315 | 320 |
316 | | - def rosenbrock_gradient(x, y): |
| 321 | + def rosenbrock_gradient(x: float, y: float) -> List[float]: |
317 | 322 | """Gradient of Rosenbrock function""" |
318 | 323 | df_dx = -400 * x * (y - x * x) - 2 * (1 - x) |
319 | 324 | df_dy = 200 * (y - x * x) |
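
The analytic gradient in this hunk is easy to sanity-check against central differences. A self-contained sketch (the two functions are restated so it runs outside the module; `numeric_gradient` and the extra test points are illustrative, only `(-1.0, 1.0)` comes from the demo's starting values):

```python
from typing import List

def rosenbrock(x: float, y: float) -> float:
    """Rosenbrock function: f(x,y) = 100*(y-x²)² + (1-x)²"""
    return 100 * (y - x * x) ** 2 + (1 - x) ** 2

def rosenbrock_gradient(x: float, y: float) -> List[float]:
    """Analytic gradient, restated from the hunk above."""
    return [-400 * x * (y - x * x) - 2 * (1 - x), 200 * (y - x * x)]

def numeric_gradient(x: float, y: float, h: float = 1e-6) -> List[float]:
    # Central differences along each coordinate; O(h²) truncation error.
    return [
        (rosenbrock(x + h, y) - rosenbrock(x - h, y)) / (2 * h),
        (rosenbrock(x, y + h) - rosenbrock(x, y - h)) / (2 * h),
    ]

for x, y in [(-1.0, 1.0), (0.5, 0.5), (1.0, 1.0)]:
    analytic = rosenbrock_gradient(x, y)
    numeric = numeric_gradient(x, y)
    assert all(abs(a - n) < 1e-3 for a, n in zip(analytic, numeric)), (x, y)
print("analytic gradient matches central differences at all test points")
```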