00001 #ifndef __NEURALNETWORK_H
00002 #define __NEURALNETWORK_H
00003
#include <map>
#include <string>
#include <valarray>
#include <vector>
00006
00007 #include "../Controller.h"
00008 #include "../GenotypeDecoder.h"
00009 #include "../../utils/Graph.h"
00010 #include "Neuron.h"
00011
00016 namespace Teem
00017 {
00020 class GraphNeuralNetwork : public Controller
00021 {
00022 public:
00024 GraphNeuralNetwork(const Graph &graph, const std::string &root = "teem.nn");
00026 virtual ~GraphNeuralNetwork();
00027
00028
00029 virtual void setInput(unsigned index, double val);
00030 virtual double getOutput(unsigned index);
00031 virtual void step();
00032
00033
00035 Neuron* getNeuron(size_t index) { return neurons[index]; }
00037 Neuron* getNeuron(const std::string &tag, size_t index) { return neuronLists[tag][index]; }
00039 std::valarray<Neuron*>& getTaggedNeurons(const std::string &tag) { return neuronLists[tag]; }
00041 size_t neuronNum() const { return neurons.size(); }
00043 size_t neuronNum(const std::string &tag) const { return neuronLists.find(tag)->second.size(); }
00045 Synapse* getSynapse(size_t index) { return synapses[index]; }
00047 Synapse* getSynapse(const std::string &tag, size_t index) { return synapseLists[tag][index]; }
00049 std::valarray<Synapse*>& getTaggedSynapses(const std::string &tag) { return synapseLists[tag]; }
00051 size_t synapseNum() const { return synapses.size(); }
00053 size_t synapseNum(const std::string &tag) const { return synapseLists.find(tag)->second.size(); }
00054
00055 protected:
00057 std::valarray<Neuron*> neurons;
00059 std::valarray<Synapse*> synapses;
00060
00062 std::map<std::string, std::valarray<Neuron*> > neuronLists;
00064 std::map<std::string, std::valarray<Synapse*> > synapseLists;
00065
00068 std::valarray<InputNeuron*> inputNeurons;
00070 std::valarray<Neuron*>* outputNeurons;
00071 };
00072
00073
00081 class GraphFeedForwardGenotypeDecoder : public GraphGenotypeDecoder
00082 {
00083 protected:
00084 Ishtar::Variable<unsigned> hiddenCount;
00085 Ishtar::Variable<double> weightRange;
00086
00087 public:
00089 GraphFeedForwardGenotypeDecoder(const std::string &root) :
00090 GraphGenotypeDecoder(root),
00091 hiddenCount(root + ".hiddenCount", 0),
00092 weightRange(root + ".weightRange", 2.0) { }
00094 virtual ~GraphFeedForwardGenotypeDecoder() { }
00095
00096
00097 virtual Genome* createGenome(void);
00098 virtual Controller* decode(Genome *genome);
00099 };
00100
00108 class GraphFullyConnectedGenotypeDecoder : public GraphGenotypeDecoder
00109 {
00110 protected:
00111 Ishtar::Variable<unsigned> hiddenCount;
00112 Ishtar::Variable<double> weightRange;
00113
00114 public:
00116 GraphFullyConnectedGenotypeDecoder(const std::string &root) :
00117 GraphGenotypeDecoder(root),
00118 hiddenCount(root + ".hiddenCount", 0),
00119 weightRange(root + ".weightRange", 2.0) { }
00121 virtual ~GraphFullyConnectedGenotypeDecoder() { }
00122
00123
00124 virtual Genome* createGenome(void);
00125 virtual Controller* decode(Genome *genome);
00126 };
00127
00128 }
00129
00130 #endif