/**
 * Predicts targets for a single sample or a batch of samples by running the
 * libsvm `svm-predict` binary against the trained model.
 *
 * Three temporary files are used under $this->varPath:
 * - the serialized test set;
 * - the model;
 * - the output file that `svm-predict` writes.
 * All of them are removed afterwards, even when prediction fails.
 *
 * @param array $samples a single sample (flat feature array) or a list of samples
 *
 * @return mixed a single prediction when $samples is a single sample, otherwise
 *               an array of predictions (an empty array for an empty batch)
 *
 * @throws \RuntimeException when svm-predict exits with a non-zero status or
 *                           its output file cannot be read
 */
public function predict(array $samples)
{
    // Nothing to predict: avoid touching the filesystem and indexing the
    // non-existent $samples[0] below.
    if ($samples === []) {
        return [];
    }

    $testSetFileName = $this->varPath . uniqid();
    $modelFileName = $testSetFileName . '-model';
    $outputFileName = $testSetFileName . '-output';

    try {
        file_put_contents($testSetFileName, DataTransformer::testSet($samples));
        file_put_contents($modelFileName, $this->model);

        // Escape the program path and each argument individually.
        // escapeshellcmd() on the whole string leaves whitespace unescaped, so a
        // path containing spaces would be split into extra arguments.
        $command = sprintf(
            '%s %s %s %s',
            escapeshellarg($this->binPath . 'svm-predict' . $this->getOSExtension()),
            escapeshellarg($testSetFileName),
            escapeshellarg($modelFileName),
            escapeshellarg($outputFileName)
        );

        $output = [];
        $exitCode = 0;
        exec($command, $output, $exitCode);

        // Without this check a failed run would fall through and produce
        // bogus predictions from a missing or partial output file.
        if ($exitCode !== 0) {
            throw new \RuntimeException(sprintf(
                'svm-predict failed with exit code %d: %s',
                $exitCode,
                implode(PHP_EOL, $output)
            ));
        }

        $predictions = file_get_contents($outputFileName);
        if ($predictions === false) {
            throw new \RuntimeException(sprintf(
                'Unable to read svm-predict output file "%s"',
                $outputFileName
            ));
        }
    } finally {
        // Always clean up. The output file may not exist if the binary failed
        // before writing it, so guard each unlink.
        foreach ([$testSetFileName, $modelFileName, $outputFileName] as $fileName) {
            if (file_exists($fileName)) {
                unlink($fileName);
            }
        }
    }

    // Classification types map libsvm's numeric output back to the original
    // labels; other types return the raw per-line values.
    if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
        $predictions = DataTransformer::predictions($predictions, $this->labels);
    } else {
        $predictions = explode(PHP_EOL, trim($predictions));
    }

    // A single (non-nested) sample yields a single prediction.
    if (!is_array($samples[0])) {
        return $predictions[0];
    }

    return $predictions;
}
/**
 * A C-SVC model with an RBF kernel, trained on three well-separated 2-D
 * clusters, should assign each unseen point to its nearest cluster's label.
 */
public function testPredictSampleFromMultipleClassWithRbfKernel()
{
    // Arrange: three clusters labelled 'a', 'b' and 'c'.
    $trainingSamples = [
        [1, 3], [1, 4], [1, 4],
        [3, 1], [4, 1], [4, 2],
        [-3, -1], [-4, -1], [-4, -2],
    ];
    $trainingLabels = ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'];

    $classifier = new SupportVectorMachine(Type::C_SVC, Kernel::RBF, 100.0);
    $classifier->train($trainingSamples, $trainingLabels);

    // Act: one probe point near each cluster.
    $predicted = $classifier->predict([[1, 5], [4, 3], [-4, -3]]);

    // Assert: probes are classified in cluster order.
    foreach (['a', 'b', 'c'] as $position => $expectedLabel) {
        $this->assertEquals($expectedLabel, $predicted[$position]);
    }
}