1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.commons.math3.ml.neuralnet;
19 
20 import java.io.Serializable;
21 import java.io.ObjectInputStream;
22 import java.util.NoSuchElementException;
23 import java.util.List;
24 import java.util.ArrayList;
25 import java.util.Set;
26 import java.util.HashSet;
27 import java.util.Collection;
28 import java.util.Iterator;
29 import java.util.Comparator;
30 import java.util.Collections;
31 import java.util.Map;
32 import java.util.concurrent.ConcurrentHashMap;
33 import java.util.concurrent.atomic.AtomicLong;
34 import org.apache.commons.math3.exception.DimensionMismatchException;
35 import org.apache.commons.math3.exception.MathIllegalStateException;
36 
37 /**
38  * Neural network, composed of {@link Neuron} instances and the links
39  * between them.
40  *
41  * Although updating a neuron's state is thread-safe, modifying the
42  * network's topology (adding or removing links) is not.
43  *
44  * @since 3.3
45  */
46 public class Network
47     implements Iterable<Neuron>,
48                Serializable {
49     /** Serializable. */
50     private static final long serialVersionUID = 20130207L;
51     /** Neurons. */
52     private final ConcurrentHashMap<Long, Neuron> neuronMap
53         = new ConcurrentHashMap<Long, Neuron>();
54     /** Next available neuron identifier. */
55     private final AtomicLong nextId;
56     /** Neuron's features set size. */
57     private final int featureSize;
58     /** Links. */
59     private final ConcurrentHashMap<Long, Set<Long>> linkMap
60         = new ConcurrentHashMap<Long, Set<Long>>();
61 
62     /**
63      * Comparator that prescribes an order of the neurons according
64      * to the increasing order of their identifier.
65      */
66     public static class NeuronIdentifierComparator
67         implements Comparator<Neuron>,
68                    Serializable {
69         /** Version identifier. */
70         private static final long serialVersionUID = 20130207L;
71 
72         /** {@inheritDoc} */
compare(Neuron a, Neuron b)73         public int compare(Neuron a,
74                            Neuron b) {
75             final long aId = a.getIdentifier();
76             final long bId = b.getIdentifier();
77             return aId < bId ? -1 :
78                 aId > bId ? 1 : 0;
79         }
80     }
81 
82     /**
83      * Constructor with restricted access, solely used for deserialization.
84      *
85      * @param nextId Next available identifier.
86      * @param featureSize Number of features.
87      * @param neuronList Neurons.
88      * @param neighbourIdList Links associated to each of the neurons in
89      * {@code neuronList}.
90      * @throws MathIllegalStateException if an inconsistency is detected
91      * (which probably means that the serialized form has been corrupted).
92      */
Network(long nextId, int featureSize, Neuron[] neuronList, long[][] neighbourIdList)93     Network(long nextId,
94             int featureSize,
95             Neuron[] neuronList,
96             long[][] neighbourIdList) {
97         final int numNeurons = neuronList.length;
98         if (numNeurons != neighbourIdList.length) {
99             throw new MathIllegalStateException();
100         }
101 
102         for (int i = 0; i < numNeurons; i++) {
103             final Neuron n = neuronList[i];
104             final long id = n.getIdentifier();
105             if (id >= nextId) {
106                 throw new MathIllegalStateException();
107             }
108             neuronMap.put(id, n);
109             linkMap.put(id, new HashSet<Long>());
110         }
111 
112         for (int i = 0; i < numNeurons; i++) {
113             final long aId = neuronList[i].getIdentifier();
114             final Set<Long> aLinks = linkMap.get(aId);
115             for (Long bId : neighbourIdList[i]) {
116                 if (neuronMap.get(bId) == null) {
117                     throw new MathIllegalStateException();
118                 }
119                 addLinkToLinkSet(aLinks, bId);
120             }
121         }
122 
123         this.nextId = new AtomicLong(nextId);
124         this.featureSize = featureSize;
125     }
126 
127     /**
128      * @param initialIdentifier Identifier for the first neuron that
129      * will be added to this network.
130      * @param featureSize Size of the neuron's features.
131      */
Network(long initialIdentifier, int featureSize)132     public Network(long initialIdentifier,
133                    int featureSize) {
134         nextId = new AtomicLong(initialIdentifier);
135         this.featureSize = featureSize;
136     }
137 
138     /**
139      * Performs a deep copy of this instance.
140      * Upon return, the copied and original instances will be independent:
141      * Updating one will not affect the other.
142      *
143      * @return a new instance with the same state as this instance.
144      * @since 3.6
145      */
copy()146     public synchronized Network copy() {
147         final Network copy = new Network(nextId.get(),
148                                          featureSize);
149 
150 
151         for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
152             copy.neuronMap.put(e.getKey(), e.getValue().copy());
153         }
154 
155         for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
156             copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue()));
157         }
158 
159         return copy;
160     }
161 
162     /**
163      * {@inheritDoc}
164      */
iterator()165     public Iterator<Neuron> iterator() {
166         return neuronMap.values().iterator();
167     }
168 
169     /**
170      * Creates a list of the neurons, sorted in a custom order.
171      *
172      * @param comparator {@link Comparator} used for sorting the neurons.
173      * @return a list of neurons, sorted in the order prescribed by the
174      * given {@code comparator}.
175      * @see NeuronIdentifierComparator
176      */
getNeurons(Comparator<Neuron> comparator)177     public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) {
178         final List<Neuron> neurons = new ArrayList<Neuron>();
179         neurons.addAll(neuronMap.values());
180 
181         Collections.sort(neurons, comparator);
182 
183         return neurons;
184     }
185 
186     /**
187      * Creates a neuron and assigns it a unique identifier.
188      *
189      * @param features Initial values for the neuron's features.
190      * @return the neuron's identifier.
191      * @throws DimensionMismatchException if the length of {@code features}
192      * is different from the expected size (as set by the
193      * {@link #Network(long,int) constructor}).
194      */
createNeuron(double[] features)195     public long createNeuron(double[] features) {
196         if (features.length != featureSize) {
197             throw new DimensionMismatchException(features.length, featureSize);
198         }
199 
200         final long id = createNextId();
201         neuronMap.put(id, new Neuron(id, features));
202         linkMap.put(id, new HashSet<Long>());
203         return id;
204     }
205 
206     /**
207      * Deletes a neuron.
208      * Links from all neighbours to the removed neuron will also be
209      * {@link #deleteLink(Neuron,Neuron) deleted}.
210      *
211      * @param neuron Neuron to be removed from this network.
212      * @throws NoSuchElementException if {@code n} does not belong to
213      * this network.
214      */
deleteNeuron(Neuron neuron)215     public void deleteNeuron(Neuron neuron) {
216         final Collection<Neuron> neighbours = getNeighbours(neuron);
217 
218         // Delete links to from neighbours.
219         for (Neuron n : neighbours) {
220             deleteLink(n, neuron);
221         }
222 
223         // Remove neuron.
224         neuronMap.remove(neuron.getIdentifier());
225     }
226 
227     /**
228      * Gets the size of the neurons' features set.
229      *
230      * @return the size of the features set.
231      */
getFeaturesSize()232     public int getFeaturesSize() {
233         return featureSize;
234     }
235 
236     /**
237      * Adds a link from neuron {@code a} to neuron {@code b}.
238      * Note: the link is not bi-directional; if a bi-directional link is
239      * required, an additional call must be made with {@code a} and
240      * {@code b} exchanged in the argument list.
241      *
242      * @param a Neuron.
243      * @param b Neuron.
244      * @throws NoSuchElementException if the neurons do not exist in the
245      * network.
246      */
addLink(Neuron a, Neuron b)247     public void addLink(Neuron a,
248                         Neuron b) {
249         final long aId = a.getIdentifier();
250         final long bId = b.getIdentifier();
251 
252         // Check that the neurons belong to this network.
253         if (a != getNeuron(aId)) {
254             throw new NoSuchElementException(Long.toString(aId));
255         }
256         if (b != getNeuron(bId)) {
257             throw new NoSuchElementException(Long.toString(bId));
258         }
259 
260         // Add link from "a" to "b".
261         addLinkToLinkSet(linkMap.get(aId), bId);
262     }
263 
264     /**
265      * Adds a link to neuron {@code id} in given {@code linkSet}.
266      * Note: no check verifies that the identifier indeed belongs
267      * to this network.
268      *
269      * @param linkSet Neuron identifier.
270      * @param id Neuron identifier.
271      */
addLinkToLinkSet(Set<Long> linkSet, long id)272     private void addLinkToLinkSet(Set<Long> linkSet,
273                                   long id) {
274         linkSet.add(id);
275     }
276 
277     /**
278      * Deletes the link between neurons {@code a} and {@code b}.
279      *
280      * @param a Neuron.
281      * @param b Neuron.
282      * @throws NoSuchElementException if the neurons do not exist in the
283      * network.
284      */
deleteLink(Neuron a, Neuron b)285     public void deleteLink(Neuron a,
286                            Neuron b) {
287         final long aId = a.getIdentifier();
288         final long bId = b.getIdentifier();
289 
290         // Check that the neurons belong to this network.
291         if (a != getNeuron(aId)) {
292             throw new NoSuchElementException(Long.toString(aId));
293         }
294         if (b != getNeuron(bId)) {
295             throw new NoSuchElementException(Long.toString(bId));
296         }
297 
298         // Delete link from "a" to "b".
299         deleteLinkFromLinkSet(linkMap.get(aId), bId);
300     }
301 
302     /**
303      * Deletes a link to neuron {@code id} in given {@code linkSet}.
304      * Note: no check verifies that the identifier indeed belongs
305      * to this network.
306      *
307      * @param linkSet Neuron identifier.
308      * @param id Neuron identifier.
309      */
deleteLinkFromLinkSet(Set<Long> linkSet, long id)310     private void deleteLinkFromLinkSet(Set<Long> linkSet,
311                                        long id) {
312         linkSet.remove(id);
313     }
314 
315     /**
316      * Retrieves the neuron with the given (unique) {@code id}.
317      *
318      * @param id Identifier.
319      * @return the neuron associated with the given {@code id}.
320      * @throws NoSuchElementException if the neuron does not exist in the
321      * network.
322      */
getNeuron(long id)323     public Neuron getNeuron(long id) {
324         final Neuron n = neuronMap.get(id);
325         if (n == null) {
326             throw new NoSuchElementException(Long.toString(id));
327         }
328         return n;
329     }
330 
331     /**
332      * Retrieves the neurons in the neighbourhood of any neuron in the
333      * {@code neurons} list.
334      * @param neurons Neurons for which to retrieve the neighbours.
335      * @return the list of neighbours.
336      * @see #getNeighbours(Iterable,Iterable)
337      */
getNeighbours(Iterable<Neuron> neurons)338     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
339         return getNeighbours(neurons, null);
340     }
341 
342     /**
343      * Retrieves the neurons in the neighbourhood of any neuron in the
344      * {@code neurons} list.
345      * The {@code exclude} list allows to retrieve the "concentric"
346      * neighbourhoods by removing the neurons that belong to the inner
347      * "circles".
348      *
349      * @param neurons Neurons for which to retrieve the neighbours.
350      * @param exclude Neurons to exclude from the returned list.
351      * Can be {@code null}.
352      * @return the list of neighbours.
353      */
getNeighbours(Iterable<Neuron> neurons, Iterable<Neuron> exclude)354     public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
355                                             Iterable<Neuron> exclude) {
356         final Set<Long> idList = new HashSet<Long>();
357 
358         for (Neuron n : neurons) {
359             idList.addAll(linkMap.get(n.getIdentifier()));
360         }
361         if (exclude != null) {
362             for (Neuron n : exclude) {
363                 idList.remove(n.getIdentifier());
364             }
365         }
366 
367         final List<Neuron> neuronList = new ArrayList<Neuron>();
368         for (Long id : idList) {
369             neuronList.add(getNeuron(id));
370         }
371 
372         return neuronList;
373     }
374 
375     /**
376      * Retrieves the neighbours of the given neuron.
377      *
378      * @param neuron Neuron for which to retrieve the neighbours.
379      * @return the list of neighbours.
380      * @see #getNeighbours(Neuron,Iterable)
381      */
getNeighbours(Neuron neuron)382     public Collection<Neuron> getNeighbours(Neuron neuron) {
383         return getNeighbours(neuron, null);
384     }
385 
386     /**
387      * Retrieves the neighbours of the given neuron.
388      *
389      * @param neuron Neuron for which to retrieve the neighbours.
390      * @param exclude Neurons to exclude from the returned list.
391      * Can be {@code null}.
392      * @return the list of neighbours.
393      */
getNeighbours(Neuron neuron, Iterable<Neuron> exclude)394     public Collection<Neuron> getNeighbours(Neuron neuron,
395                                             Iterable<Neuron> exclude) {
396         final Set<Long> idList = linkMap.get(neuron.getIdentifier());
397         if (exclude != null) {
398             for (Neuron n : exclude) {
399                 idList.remove(n.getIdentifier());
400             }
401         }
402 
403         final List<Neuron> neuronList = new ArrayList<Neuron>();
404         for (Long id : idList) {
405             neuronList.add(getNeuron(id));
406         }
407 
408         return neuronList;
409     }
410 
411     /**
412      * Creates a neuron identifier.
413      *
414      * @return a value that will serve as a unique identifier.
415      */
createNextId()416     private Long createNextId() {
417         return nextId.getAndIncrement();
418     }
419 
420     /**
421      * Prevents proxy bypass.
422      *
423      * @param in Input stream.
424      */
readObject(ObjectInputStream in)425     private void readObject(ObjectInputStream in) {
426         throw new IllegalStateException();
427     }
428 
429     /**
430      * Custom serialization.
431      *
432      * @return the proxy instance that will be actually serialized.
433      */
writeReplace()434     private Object writeReplace() {
435         final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]);
436         final long[][] neighbourIdList = new long[neuronList.length][];
437 
438         for (int i = 0; i < neuronList.length; i++) {
439             final Collection<Neuron> neighbours = getNeighbours(neuronList[i]);
440             final long[] neighboursId = new long[neighbours.size()];
441             int count = 0;
442             for (Neuron n : neighbours) {
443                 neighboursId[count] = n.getIdentifier();
444                 ++count;
445             }
446             neighbourIdList[i] = neighboursId;
447         }
448 
449         return new SerializationProxy(nextId.get(),
450                                       featureSize,
451                                       neuronList,
452                                       neighbourIdList);
453     }
454 
455     /**
456      * Serialization.
457      */
458     private static class SerializationProxy implements Serializable {
459         /** Serializable. */
460         private static final long serialVersionUID = 20130207L;
461         /** Next identifier. */
462         private final long nextId;
463         /** Number of features. */
464         private final int featureSize;
465         /** Neurons. */
466         private final Neuron[] neuronList;
467         /** Links. */
468         private final long[][] neighbourIdList;
469 
470         /**
471          * @param nextId Next available identifier.
472          * @param featureSize Number of features.
473          * @param neuronList Neurons.
474          * @param neighbourIdList Links associated to each of the neurons in
475          * {@code neuronList}.
476          */
SerializationProxy(long nextId, int featureSize, Neuron[] neuronList, long[][] neighbourIdList)477         SerializationProxy(long nextId,
478                            int featureSize,
479                            Neuron[] neuronList,
480                            long[][] neighbourIdList) {
481             this.nextId = nextId;
482             this.featureSize = featureSize;
483             this.neuronList = neuronList;
484             this.neighbourIdList = neighbourIdList;
485         }
486 
487         /**
488          * Custom serialization.
489          *
490          * @return the {@link Network} for which this instance is the proxy.
491          */
readResolve()492         private Object readResolve() {
493             return new Network(nextId,
494                                featureSize,
495                                neuronList,
496                                neighbourIdList);
497         }
498     }
499 }
500