Commencer par créer le répertoire qui accueillera les fichiers JNI générés par SWIG.
mkdir -p ./GRTApp/app/src/main/java/org/swig/grt/
Ensuite on va avoir besoin du fichier grt.i ci dessous qui définit comment SWIG doit écrire les interfaces C++. C'est la première fois que j'utilise SWIG et trouver la syntaxe pour les opérateurs C++ surchargés n'a pas été simple.
%module grt
%{
#include "GRT.h"
using namespace GRT;
%}
%inline %{
typedef unsigned int UINT;
%}
%rename(assign) GRT::IndexDist::operator=;
%rename(assign) GRT::DTW::operator=;
%include
%include "my_vector.i"
%template(VectorDouble) std::vector;
%include "grt/Util/GRTTypedefs.h"
%include "grt/Util/Matrix.h"
%rename(assign) GRT::MatrixDouble::operator=;
%include "grt/Util/MatrixDouble.h"
%rename(assign) GRT::TimeSeriesClassificationSample::operator=;
%rename(get) GRT::TimeSeriesClassificationSample::operator[];
%include "grt/DataStructures/TimeSeriesClassificationSample.h"
%rename(assign) GRT::TimeSeriesClassificationData::operator=;
%rename(get) GRT::TimeSeriesClassificationData::operator[];
%include "grt/DataStructures/TimeSeriesClassificationData.h"
%rename(assign) GRT::TimeSeriesClassificationDataStream::operator=;
%rename(get) GRT::TimeSeriesClassificationDataStream::operator[];
%include "grt/DataStructures/TimeSeriesClassificationDataStream.h"
%include "grt/CoreModules/Classifier.h"
%include "grt/ClassificationModules/DTW/DTW.h"
Ensuite on lance la commande qui va générer jni/grt_wrap.cpp et tout un tas de fichiers d'interfaces dans org/swig/grtswig -c++ -java -package org.swig.grt -outdir ../java/org/swig/grt/ -o grt_wrap.cpp grt.i
Reste à compiler le tout avec:Compiler avec ndk-build NDK_LIBS_OUT=../jniLibs
On peut ensuite réaliser en java un clone du code d'example de l'algorithme DTW de la page suivante:
Voici le code java équivalent.
DTW dtw = new DTW();
TimeSeriesClassificationData trainingData = new TimeSeriesClassificationData();
if(!trainingData.loadDatasetFromFile("/data/data/com.codeflakes.grtapp/files/dtwtrainingdata_grt")) {
Log.e(TAG, "Failed to load training data!");
}
Log.d(TAG, trainingData.getStatsAsString());
TimeSeriesClassificationData testData = trainingData.partition(80);
dtw.enableTrimTrainingData(true, 0.1, 90);
//Train the classifier
if( !dtw.train_( trainingData ) ){
Log.e(TAG, "Failed to train classifier!");
}
//Use the test dataset to test the DTW model
double accuracy = 0;
for(int i=0; i<testData.getNumSamples(); i++){
//Get the i'th test sample - this is a timeseries
long classLabel = testData.get(i).getClassLabel();
MatrixDouble timeseries = testData.get(i).getData();
//Log.d(TAG, "" + classLabel + " " + testData.get(i).getLength() + " " + testData.get(i).getNumDimensions());
//Perform a prediction using the classifier
if (!dtw.predict_( timeseries ) ){
Log.e(TAG, "Failed to perform prediction for test sample: " + i);
}
//Get the predicted class label
long predictedClassLabel = dtw.getPredictedClassLabel();
double maximumLikelihood = dtw.getMaximumLikelihood();
VectorDouble classLikelihoods = dtw.getClassLikelihoods();
VectorDouble classDistances = dtw.getClassDistances();
//Update the accuracy
if( classLabel == predictedClassLabel ) accuracy++;
Log.d(TAG, "TestSample: " + i + "\tClassLabel: " + classLabel + "\tPredictedClassLabel: " + predictedClassLabel + "\tMaximumLikelihood: " + maximumLikelihood);
}
Log.d(TAG, "Test Accuracy: " + (accuracy/(testData.getNumSamples())*100.0) + "%");
Aucun commentaire:
Publier un commentaire