vendredi 18 décembre 2015

Utiliser le Gesture Recognition Toolkit (GRT) sur Android via JNI

Une fois GRT compilé pour Android la partie la plus complexe consiste à utiliser SWIG afin de générer automatiquement l'interface JNI pour appeler le C++ via Java. Cette page est une traduction et adaptation de ce billet qui m'a fait découvrir SWIG. La procédure à suivre est la suivante:
Commencer par créer le répertoire qui accueillera les fichiers JNI générés par SWIG.
mkdir -p ./GRTApp/app/src/main/java/org/swig/grt/ Ensuite on va avoir besoin du fichier grt.i ci dessous qui définit comment SWIG doit écrire les interfaces C++. C'est la première fois que j'utilise SWIG et trouver la syntaxe pour les opérateurs C++ surchargés n'a pas été simple.
%module grt

%{
#include "GRT.h"
using namespace GRT;
%}

%inline %{
typedef unsigned int UINT;
%}

%rename(assign) GRT::IndexDist::operator=;
%rename(assign) GRT::DTW::operator=;

%include
%include "my_vector.i"
%template(VectorDouble) std::vector;
%include "grt/Util/GRTTypedefs.h"

%include "grt/Util/Matrix.h"
%rename(assign) GRT::MatrixDouble::operator=;
%include "grt/Util/MatrixDouble.h"

%rename(assign) GRT::TimeSeriesClassificationSample::operator=;
%rename(get) GRT::TimeSeriesClassificationSample::operator[];
%include "grt/DataStructures/TimeSeriesClassificationSample.h"

%rename(assign) GRT::TimeSeriesClassificationData::operator=;
%rename(get) GRT::TimeSeriesClassificationData::operator[];
%include "grt/DataStructures/TimeSeriesClassificationData.h"

%rename(assign) GRT::TimeSeriesClassificationDataStream::operator=;
%rename(get) GRT::TimeSeriesClassificationDataStream::operator[];
%include "grt/DataStructures/TimeSeriesClassificationDataStream.h"

%include "grt/CoreModules/Classifier.h"
%include "grt/ClassificationModules/DTW/DTW.h"
Ensuite on lance la commande qui va générer jni/grt_wrap.cpp et tout un tas de fichiers d'interfaces dans org/swig/grt
swig -c++ -java -package org.swig.grt -outdir ../java/org/swig/grt/ -o grt_wrap.cpp grt.i Reste à compiler le tout avec:
Compiler avec ndk-build NDK_LIBS_OUT=../jniLibs On peut ensuite réaliser en java un clone du code d'example de l'algorithme DTW de la page suivante: Voici le code java équivalent.
DTW dtw = new DTW();

TimeSeriesClassificationData trainingData = new TimeSeriesClassificationData();

if(!trainingData.loadDatasetFromFile("/data/data/com.codeflakes.grtapp/files/dtwtrainingdata_grt")) {
    Log.e(TAG, "Failed to load training data!");
}

Log.d(TAG, trainingData.getStatsAsString());

TimeSeriesClassificationData testData = trainingData.partition(80);
dtw.enableTrimTrainingData(true, 0.1, 90);

//Train the classifier
if( !dtw.train_( trainingData ) ){
     Log.e(TAG, "Failed to train classifier!");
}

//Use the test dataset to test the DTW model
double accuracy = 0;
for(int i=0; i<testData.getNumSamples(); i++){
    //Get the i'th test sample - this is a timeseries
    long classLabel = testData.get(i).getClassLabel();
    MatrixDouble timeseries = testData.get(i).getData();
    //Log.d(TAG, "" + classLabel + " " + testData.get(i).getLength() + " " + testData.get(i).getNumDimensions());

    //Perform a prediction using the classifier
    if (!dtw.predict_( timeseries ) ){
        Log.e(TAG, "Failed to perform prediction for test sample: " + i);
    }

    //Get the predicted class label
    long predictedClassLabel = dtw.getPredictedClassLabel();
    double maximumLikelihood = dtw.getMaximumLikelihood();
    VectorDouble classLikelihoods = dtw.getClassLikelihoods();
    VectorDouble classDistances = dtw.getClassDistances();

    //Update the accuracy
    if( classLabel == predictedClassLabel ) accuracy++;

    Log.d(TAG, "TestSample: " + i + "\tClassLabel: " + classLabel + "\tPredictedClassLabel: " + predictedClassLabel + "\tMaximumLikelihood: " + maximumLikelihood);
}

Log.d(TAG, "Test Accuracy: " + (accuracy/(testData.getNumSamples())*100.0) + "%");

Aucun commentaire:

Enregistrer un commentaire