/**
 * Fold a new gradient into the running per-parameter statistics.
 *
 * For each entry of $gradient the stored mean becomes
 * meanBeta * mean + (1 - meanBeta) * slope, and the stored variance becomes
 * varianceBeta * variance + (1 - varianceBeta) * slope^2 (exponential moving
 * averages of the slope and of the squared slope). Afterwards the iteration
 * counter is incremented and the gradient itself is kept on the instance.
 *
 * @param array $gradient Slopes keyed by parameter index; the same keys are
 *                        used to index the stored means and variances.
 * @return void
 */
public function update(array $gradient)
{
    // Lazily create the state on the first non-empty update. It is keyed by
    // the gradient's own keys (not 0..n-1) so that the loop below always finds
    // a matching slot, even when the gradient is not a zero-based list.
    if (!$this->means) {
        $zeros = array_fill_keys(array_keys($gradient), 0.0);

        $this->means = $zeros;
        $this->variances = $zeros;
    }

    // Loop-invariant weights given to the incoming observation.
    $meanGain = 1.0 - $this->meanBeta;
    $varianceGain = 1.0 - $this->varianceBeta;

    foreach ($gradient as $i => $slope) {
        // A key that was not present when the state was created starts from
        // 0.0 instead of triggering an undefined-offset warning.
        $this->means[$i] = $this->meanBeta * ($this->means[$i] ?? 0.0)
            + $meanGain * $slope;

        $this->variances[$i] = $this->varianceBeta * ($this->variances[$i] ?? 0.0)
            + $varianceGain * $slope ** 2;
    }

    $this->iteration++;
    $this->gradient = $gradient;
}
/**
 * Feeds the same single-element gradient to an Adam instance twice and checks
 * the value returned by step(0) against a fixed expected number.
 */
public function testStep()
{
    // Arrange: constructor arguments are passed positionally, as in the
    // original test (their meaning is defined by the Adam class itself).
    $adam = new Adam(0.01, 1.0E-8, 0.9, 0.999);
    $gradient = [5.0];
    $expected = 0.0019999999959999857;

    // Act: two identical updates.
    $adam->update($gradient);
    $adam->update($gradient);
    $actual = $adam->step(0);

    // Assert.
    static::assertEquals($expected, $actual);
}