
Machine Learning with AWS & Scala

Recently, in an attempt to start learning React, I began building an akka-http backend API as a starting point. I quickly got distracted building the backend and ended up integrating with both the Twitter streaming API and AWS' Comprehend sentiment analysis API - which is what this post will be about.

Similar to an old idea, where I built an app consuming tweets about the 2015 Rugby world cup, this time my app was consuming tweets about the FIFA world cup in Russia - splitting tweets by country and recording sentiment for each one (and so a rolling average sentiment for each team).


Overview

The premise was simple:

  1. Connect to the Twitter streaming API (aka the firehose) filtering on world cup related key words
  2. Pass the body of the tweet to AWS Comprehend to get the sentiment score
  3. Update the in memory store of stats (count and average sentiment) for each country

In terms of technology used:
  1. Scala & Akka-Http
  2. Twitter4s Scala client
  3. AWS Java SDK

As always, all the code is on GitHub. To run it locally, you will need a Twitter dev API key (add an application.conf as per the readme on the Twitter4s GitHub) and an AWS key/secret - the code will look for credentials stored locally, but you can also just set them in environment variables before starting. The free tier supports up to 50,000 Comprehend API requests in the first 12 months - and as you can imagine, plugging this directly into Twitter can result in lots of calls, so make sure you restrict it (or at least monitor it) before you leave it running!


Consuming Tweets

Consuming tweets is really simple with the Twitter4s client - we just define a partial function that will handle the incoming tweet. 

The other functions for parsing countries/teams are excluded for brevity, but you can see it's quite simple: for each inbound tweet we make a call to the Sentiment Service (we will look at that later), then pass the result with the additional data to our update service, which stores it in memory. You will also see it is ridiculously easy to start the Twitter streaming client filtering by key words.
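To give a feel for the shape of that handler, here is a minimal sketch using only the standard library. The `Message`, `Tweet` and `KeepAlive` types are hypothetical stand-ins for the streaming client's message types, and the keyword lists are made up for illustration - this is not the code from the repo, just one way the country parsing could work.

```scala
object CountryTagger {
  // Hypothetical stand-ins for the streaming client's message types.
  sealed trait Message
  final case class Tweet(text: String) extends Message
  case object KeepAlive extends Message

  // Assumed keyword lists; the real app may match on different terms or hashtags.
  val keywords: Map[String, Set[String]] = Map(
    "Poland"  -> Set("poland", "#pol"),
    "Senegal" -> Set("senegal", "#sen"),
    "England" -> Set("england", "#eng")
  )

  // Returns every country with at least one keyword present as a token in the text.
  def countriesIn(text: String): Set[String] = {
    val tokens = text.toLowerCase.split("[^a-z#]+").filter(_.nonEmpty).toSet
    keywords.collect {
      case (country, terms) if terms.exists(tokens.contains) => country
    }.toSet
  }

  // The handler is a partial function: it is only defined for tweets,
  // so other stream messages (keep-alives etc.) are simply not matched.
  val handler: PartialFunction[Message, Set[String]] = {
    case Tweet(text) => countriesIn(text)
  }
}
```

For example, `CountryTagger.handler(CountryTagger.Tweet("Poland 1-2 Senegal, gutted"))` tags both Poland and Senegal, which is exactly the "both teams mentioned" situation discussed in the results below.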


Detecting Sentiment

Because I wanted to be able to stub out the sentiment analysis without being tied to AWS, you will notice I am using the self-type annotation on my Twitter class above, which requires a SentimentModule to be mixed in at construction - I am using a simple cake pattern to manage all my dependencies here. In the GitHub repo, there is also a Dummy implementation that will just pick a random number for the score, so you can still see the rest of the API working - but the interesting part is the AWS integration:

Once again, the SDK makes the integration really painless - in my code I am simplifying the actual results a lot to a much cruder Positive/Neutral/Negative rating (plus a numeric score -100..100).
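Comprehend returns an overall sentiment label plus a confidence score for each of positive, negative, neutral and mixed. A rough sketch of how those could be collapsed into the cruder rating and -100..100 score is below. The `Confidences` class, the scoring rule (positive minus negative) and the neutral threshold are all assumptions for illustration, not necessarily what the repo does.

```scala
object SentimentRating {
  sealed trait Rating
  case object Positive extends Rating
  case object Neutral extends Rating
  case object Negative extends Rating

  // Hypothetical holder for the per-class confidences (each between 0.0 and 1.0).
  final case class Confidences(positive: Double, negative: Double, neutral: Double, mixed: Double)

  // Assumed scoring rule: positive minus negative confidence, scaled to -100..100.
  def score(c: Confidences): Int =
    math.round((c.positive - c.negative) * 100).toInt

  // Assumed cut-off: anything within +/- threshold is treated as Neutral.
  def rating(score: Int, threshold: Int = 20): Rating =
    if (score > threshold) Positive
    else if (score < -threshold) Negative
    else Neutral
}
```

The threshold is a judgement call: a wider band makes the app more conservative about labelling a tweet as positive or negative.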

The AWSCredentials class is the bit that is going to look in the normal places for an AWS key.
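For anyone unfamiliar with the self-type/cake pattern approach mentioned above, this is a minimal sketch of the idea. All the names here (`SentimentService`, `FixedSentimentModule`, `TweetProcessor`) are hypothetical and chosen for illustration; only `SentimentModule` is a name taken from the post, and its real definition may differ.

```scala
object CakeSketch {
  // Hypothetical service interface.
  trait SentimentService { def score(text: String): Int }

  // The module just exposes the service as a dependency.
  trait SentimentModule { def sentimentService: SentimentService }

  // Stub module, useful for tests or for running without AWS: always returns the same score.
  trait FixedSentimentModule extends SentimentModule {
    val sentimentService: SentimentService = new SentimentService {
      def score(text: String): Int = 42
    }
  }

  // The self-type means this class can only be instantiated with a SentimentModule mixed in,
  // and can use sentimentService as if it were its own member.
  class TweetProcessor { self: SentimentModule =>
    def process(text: String): (String, Int) = (text, sentimentService.score(text))
  }

  // Wiring happens at construction: swap the mixed-in module to change the implementation.
  val processor = new TweetProcessor with FixedSentimentModule
}
```

Trying to write `new TweetProcessor` without a module mixed in fails at compile time, which is what makes the dependency explicit.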


Storing and updating our stats

So now we have our inbound tweets and a way to assess their sentiment score - I then set up a very simple akka actor to manage the state and just stored the API data in memory (if you restart the app, the store gets reset and the API stops serving data).

Again, very simple out of the box stuff for akka, but it allows easy and thread safe management of the in-memory data store. I also track a rolling list of the last twenty tweets processed, which is managed by a second, almost identical, actor.
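The state itself is simple enough to sketch without akka. Below is an illustrative, immutable version of what such an actor might hold: a count and running average per country, plus the last twenty tweets. The class and method names are assumptions, and the actor message handling is deliberately left out.

```scala
object StatsStore {
  final case class CountryStats(count: Long, averageSentiment: Double) {
    // Incremental mean: newAvg = oldAvg + (x - oldAvg) / newCount,
    // so we never need to keep the individual scores.
    def record(score: Int): CountryStats = {
      val n = count + 1
      CountryStats(n, averageSentiment + (score - averageSentiment) / n)
    }
  }

  final case class State(stats: Map[String, CountryStats], recent: Vector[String]) {
    // Applies one tweet's score to every country it mentions and
    // keeps only the most recent `keep` tweets, newest first.
    def update(countries: Set[String], text: String, score: Int, keep: Int = 20): State = {
      val updated = countries.foldLeft(stats) { (acc, country) =>
        acc.updated(country, acc.getOrElse(country, CountryStats(0, 0.0)).record(score))
      }
      State(updated, (text +: recent).take(keep))
    }
  }

  val empty: State = State(Map.empty, Vector.empty)
}
```

Because each update returns a new `State`, an actor can simply replace its current state on every message, which is where the thread safety comes from: only one message is processed at a time.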


The results

I ran the app during several games; below are some sample outputs from the API. The response from the stats API is fairly boring reading (just numbers), but the example tweets show two examples of a positive and neutral tweet correctly identified (apologies for the expletives in the tweet about Poland - I guess that fan wasn't too happy about being beaten by the Senegalese!). You will also notice that the app captures the countries being mentioned, which exposes one flaw of the design: in the negative tweet from the Polish fan about losing two goals to Senegal, it correctly identifies the sentiment as negative, but we have no way to determine the subject. As both teams are mentioned, the app naively assigns it as a negative tweet to both teams, whereas on reading, it is clearly negative with regards to Poland (I wasn't too concerned for my experiment, of course, just an observation worth noting).

Sample tweet from the latest API:

Sample response from the stats API:

When I finally did get around to starting to learn React, I just plugged in the APIs and paid no attention to styling, which is a roundabout way of apologising for the horrible appearance of the screenshot below (I'm really sorry about the css gradient)!





We need to talk about AI

Ethical and Regulatory questions facing AI





Regardless of area of expertise, most of us are probably already aware of the momentum around Artificial Intelligence (AI). Between self driving cars, home assistants (Alexa, Google Home, et al) and the growing capabilities of our mobile devices there is no escaping the ever looming presence of AI in our lives.

Furthermore, it seems unlikely that this will slow down anytime soon. A recent Narrative Science study found that AI adoption grew by 60% in the last year, with 61% of organisations reporting that they had implemented AI within their business, and a Gartner report predicted that by 2020, 85% of customer interactions will be managed without human intervention.

But despite this growth, there is still a question mark over whether, and if so, how, the field should be regulated. Having been brought up on decades of sci-fi about AI going rogue and robots enslaving the human race, it feels like there is both the fear of this possible future, whilst also scepticism that these fears are only the stuff of movies. Elon Musk has famously warned of the future risks of AI: “I think we should be very careful about artificial intelligence. If I had to guess at what our biggest existential threat is, it’s probably that” whilst others, including Mark Zuckerberg, have downplayed the claims of doomsday scenarios as irresponsible.

So what's the big deal? AI already permeates so many aspects of life and business, but considering for a moment that these technologies could be being used to control autonomous cars on public roads, determine people’s credit score or suitability for a job, to detect illness or even in policing and judicial decision making - it is pretty clear that we should have a good understanding of these technologies and clear systems of accountability and control in place. In all these examples getting a decision wrong has the potential to ruin lives, yet there is still limited regulation, control or even understanding of the algorithms, the data and their usage.

A common analogy is with other heavily regulated industries: big pharma companies can’t release drugs without thorough testing and approval, yet several big tech companies have already started testing autonomous vehicles on public roads with limited regulatory controls (that’s not to say that they have had a completely free pass, there are varying levels of regulation, depending on the region. Arizona has long been promoting itself as an AI friendly state to try to attract business from big tech, making it as easy as possible for companies to test self driving cars with minimal regulatory friction, and they recently saw the first fatality from a self-driving car).

In its 2017 report, the AI Now Institute recommended that AI be banned outright from use in any high risk areas, such as criminal justice, healthcare, welfare and education, along with further measures for other domains - which, given the potential impact of errors in these domains, seems like a fairly sensible starting point.



Uncertainty and the unknown


One key aspect that is especially troubling is the lack of understanding of both the data and the underlying technology. We have computers being trained on millions of data points, to the point of being able to outperform humans at their tasks, so it should come as no surprise that both the inner workings and the end results could be beyond easy comprehension.

This problem has been demonstrated by several high profile mishaps from large tech companies, showing that even companies that have a wealth of resources and technical expertise in the domain can be caught out - such as Microsoft’s AI chatbot Tay, which quickly became racist when released into the wild. Clearly Microsoft had neither intended nor envisaged that end result. Similarly, when Google Translate revealed gender bias in pairing “he” with “hardworking” and “she” with “lazy”, it clearly wasn’t an intentional or foreseen behaviour, but it eventually revealed itself with wider usage.



Understanding where bias in AI comes from


To get a better understanding of where these biases and blind spots come from, let’s take a look at how AI learns. Broadly speaking, there are three primary approaches to training AI: Supervised, Unsupervised and Reinforcement.

Unsupervised learning is where the AI is fed very large amounts of raw data - for example an entire corpus of fictional texts - and it is left to work out patterns or groupings. That is, it doesn’t know a right or wrong answer, but can identify related things from the dataset and group them together (for example, AI reading popular fiction might group together terms such as “batman” and “wonder woman”, but it would have no knowledge of what these terms actually mean).

Supervised learning is where the AI is fed very large amounts of marked up data - that is, for each input, it also gets passed the expected output. An example of this is if you had a large set of photos (say Google Photos) which are pre-tagged with descriptions of what is in the photo, the dataset could be used to train an AI to identify contents of a photo.

Reinforcement learning is similar to supervised in as much as the algorithm gets information as to whether or not it is performing well (like knowing the answer for a given input) but is in the form of a feedback loop and works more like a trial-and-error approach to learning (it might have a general fitness score function that can be used by the algorithm to determine whether or not its response to given input has been successful or not and adjust its response for the next cycle). The simplest example of this is something like AlphaGo/AlphaZero, where an algorithm learns to play a game like Go or chess by trial and error and gets feedback on its attempted response from the game itself.

Both Supervised and Unsupervised learning cases require vast amounts of data to accurately train AI, which really leads us to one of the primary challenges for building fair and ethical AI: sourcing the data to train on. AI is dependent on these huge datasets, and finely tuned to all the details and subtle underlying patterns, regardless of whether we are aware of them or not, and as we will see, getting objective, raw data sets of sufficient magnitude is rife with challenges.



Institutional bias


Similar to the concept of Conway’s Law, which states “any organization that designs a system (defined broadly) will produce a design whose structure is a copy of the organization's communication structure”, the data we naturally generate in action, conversation and interactions as a society or organisation will naturally reflect the values, beliefs and structure of the society (or organisation). There is an intrinsic and inescapable subjectivity in all big data, best described by Lisa Gitelman in her book Raw Data is an Oxymoron:

Objectivity is situated and historically specific; it comes from somewhere and is the result of ongoing changes to the conditions of inquiry, conditions that are at once material, social, and ethical

A simple example of this could be in criminal statistics: if a police force stop-and-search a particular demographic more heavily than others, then that will be reflected in the numbers and therefore that cultural subjectivity influences the data set - this subjectivity will then naturally carry over to, and likely be amplified by, the trained AI as it becomes finely tuned to the data (an example of this was seen where some software used to inform sentencing decisions relied on data that had institutional bias, which resulted in a racial bias in the risk assessment - strengthening the AI Now report’s proposal of banning AI use in these areas).



Finding complete & representative data


Compounding this problem is the fact that researchers working in AI face the challenge of finding datasets that are big enough and permitted for such use, which can be hard to come by, meaning they often make-do with incomplete or skewed datasets. For example, the popular community discussion web site Reddit makes its vast historic dataset publicly available, which is a rich source of natural text and conversation, and makes for a very tempting dataset for engineers and researchers to take advantage of - however, Reddit is a very specific subset of the internet, and the real world demographic, meaning that whilst there is undoubtedly a lot that can be learnt from that wealth of data, any AI trained on it will be heavily subjective.

There have been several reports finding that these incomplete or skewed data sets just further add to the bias. The 2017 AI Now report said:

data can easily privilege socioeconomically advantaged populations, those with greater access to connected devices and online services

Which is to be expected when you think about it really - always connected people with mobile devices will naturally be generating a lot more data than those without easy access to computers. On a very simple level, the core regular users of Reddit, for example, will likely have access to mobile devices or at the very least have access to computers and the internet - which rules out large parts of the population - not to mention the inclination to partake in the online community.

There are also other challenges that are intrinsic to the way AI currently works: if we have a dataset where a particular demographic is only reflected by 1% of the data, then the AI could claim to achieve 99% accuracy whilst being completely inaccurate for all of that 1% minority. Furthermore, we know that there is a strong relationship between the amount of training data and the accuracy of AI, so in the scenario we have a perfect representation of the population, by definition, all minority groups will have a smaller selection of data points to train on so inevitably the performance of the AI for minority groups will fare worse.
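The arithmetic behind that 99% claim is easy to check with a made-up evaluation set. The sketch below is purely illustrative: the group names and the 990/10 split are invented to mirror the 1% example above.

```scala
object AccuracyByGroup {
  // Each result is a (group, wasPredictionCorrect) pair from a hypothetical evaluation set.
  def accuracy(results: Seq[(String, Boolean)]): Double =
    results.count(_._2).toDouble / results.size

  def accuracyFor(group: String, results: Seq[(String, Boolean)]): Double =
    accuracy(results.filter(_._1 == group))

  // 990 majority examples, all predicted correctly; 10 minority examples, all predicted wrongly.
  val results: Seq[(String, Boolean)] =
    Seq.fill(990)("majority" -> true) ++ Seq.fill(10)("minority" -> false)
}
```

Here `accuracy(results)` comes out at 0.99 while `accuracyFor("minority", results)` is 0.0, which is why a single headline accuracy figure can hide a complete failure for a small group, and why per-group metrics are worth reporting.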

Finally let’s consider again that we have a huge, rich dataset (the ideal scenario), and we try to intentionally exclude sensitive features that might explicitly encode bias: race, gender, age, etc. There are still loads of data points that may act as an indirect proxy for these features, so even without including race, gender and age in the input data, it is easy to see how these features can get encoded in other data points such as names, location, interests and communication style. This makes it even harder to detect and prevent bias in our datasets.

There is no objectivity in big data.



How can we address the problem?


Some of these examples might have clearer cases of existing bias that we need to address in training our AI, but a tougher challenge is how we can address the more subtle biases hidden in the cultural objectivity that we might not even be aware of. We all carry our own opinions and biases that subconsciously affect our attitudes toward things - but if we are not consciously aware of those, how can we ensure that developers training AI have the foresight to engineer around them?

This issue highlights one often recommended approach to tackling the problem: a greater emphasis on diversity in the teams building AI, both in terms of individual identities and in terms of cross-functional teams. Statistically and broadly speaking, AI is often developed by teams of engineers with limited diversity, which results in a limited range of views when thinking about the dataset and about what goals are optimised for in the training process. The 2017 AI Now report recommended:

“stakeholders in the AI field should release data on the participation of women, minorities and other marginalised groups within AI research and development.”

Aside from trying to recognise subtle bias in the data, we also need to consider that the objective norm, and what we consider to be OK at the moment, is changing. Going back to Lisa Gitelman’s quote: “Objectivity is situated and historically specific”. If you could get a dataset from even just two decades ago, it’s not hard to imagine that AI trained on that would have unacceptable biases, because the societal norm and general attitudes to race, gender, identity, etc. have changed significantly since then.

As a simple example, take the motor insurance industry. For decades, insurance companies identified young male drivers as a particularly high risk of accident, so traditionally charged much higher premiums for that demographic - previously a widely accepted approach, and one based in statistics: young male drivers were statistically more likely to have an accident behind the wheel. But then, in 2012, EU gender discrimination regulation came into effect that prevented companies charging men more than women, so now the insurers have stopped that categorisation for pricing despite the data being available. If that was AI it would need to be re-trained with a modified dataset, with gender probably removed from the data and thought put into other data points that would also need to be removed (names, for example, might very easily be a broad proxy to gender). Whilst this is a simpler example, as it’s a binary change in legislation with clear requirements, there are also the more gradual shifts in attitude where it becomes a lot fuzzier - like the changes in attitudes on race, gender and sexuality over the last thirty years.

We previously discussed the idea that even if we exclude socially salient data points, such as gender, those features can still get encoded via other proxies in the data, and this example of the change in EU regulation and its effect on the insurance industry provides an interesting case study in exactly that phenomenon. There was an article written in the Guardian following the EU ruling, explaining that, despite the ruling meaning insurers couldn’t charge more because a driver was male, male premiums have actually increased in comparison to female premiums since. The reasoning they provide is that rather than classifying on the crude data point of gender, the system instead places greater importance on a wider set of data points, and it turns out that these other data points are really just acting as encoded proxies (they list car size, occupation, vehicle modifications). The article makes the observation that MoneySupermarket released a study showing that 8 out of the worst 10 occupations for drink/drug drive incidents were in the building trade, with midwives being the least likely to have a drink/drug drive offence, the suggestion being that the building trade is predominantly male, and midwives predominantly female.



It certainly seems to me like there are still lots of challenges as to how we can foresee potential problems and how to tackle them. A key starting point will be ensuring teams working in the area have a good understanding of the dataset they are working with: where it comes from, any inherent bias or blind spots and which of the data points might need modifying or weighting due to their contextual/social salience. This will need to be driven through agreed best practices and AI development standards from organisations like AI Now and from academia, as well as a need for appropriate regulatory controls (although these face their own challenges, which I will discuss in a later article).

I also believe that these challenges mean an even greater need for diversity in the teams - both in terms of the race, background, gender etc. of the team members, and also in being cross-functional: not just engineers, but engineers working closely with the specific domain experts for the field.




Photo credits:

Heading Photo by Alex Knight on Unsplash
Anonymous person Photo by Andrew Worley on Unsplash

Conference updates: JAXLondon and W-JAX 2017

I have had the pleasure of closing out this year by speaking at two conferences. The first in October was JAXLondon, the second in November was W-JAX in Munich - the talk was titled "Agile Machine Learning: From theory to production".



As the title might suggest, the talk was about some considerations and challenges of doing Machine Learning (ML) work in a commercial environment - so a lot of softer aspects of ML: when you should adopt it, how you can work as a team on ML.

The talk at JAXLondon was recorded, so hopefully at some point that will be available for me to share as well, but for the time being, here is a brief interview that my co-speaker and I gave ahead of our JAXLondon talk.



See more details and write up here

Unsupervised Learning in Scala using Word2Vec

A pretty cool thing that has come out of recent Machine Learning advancements is the idea of "Word Embedding", specifically the advancements in the field made by Tomas Mikolov and his team at Google with the Word2Vec approach. Word Embedding is a language modelling approach that involves mapping words to vectors of numbers. If you imagine we are modelling every word in a given body of text as an N-dimension vector (it might be easier to visualise this in 2 dimensions - so each word is a pair of co-ordinates that can be plotted on a graph), then that could be useful for plotting words and starting to understand relationships between words given their proximity. What's more, if we could map words to sets of numbers, then we could start thinking about interesting arithmetic that we could perform on the words.
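To make "proximity" concrete, here is a tiny sketch with hand-picked 2-dimension vectors (these numbers are made up for illustration, not output from a trained model). It uses cosine similarity, which is a common way of measuring how close two word vectors are.

```scala
object ToyEmbeddings {
  type Vec = Vector[Double]

  // Hand-picked 2-dimension vectors, purely illustrative (not trained values).
  val words: Map[String, Vec] = Map(
    "lamb"     -> Vector(0.9, 0.2),
    "beef"     -> Vector(0.8, 0.3),
    "rosemary" -> Vector(0.1, 0.9)
  )

  def dot(a: Vec, b: Vec): Double = a.zip(b).map { case (x, y) => x * y }.sum
  def norm(a: Vec): Double = math.sqrt(dot(a, a))

  // Cosine similarity: 1.0 means the vectors point the same way, 0.0 means unrelated.
  def cosine(a: Vec, b: Vec): Double = dot(a, b) / (norm(a) * norm(b))

  // The nearest word is the other word whose vector has the highest cosine similarity.
  def nearest(word: String): String =
    (words - word).maxBy { case (_, v) => cosine(words(word), v) }._1
}
```

With these toy values, `ToyEmbeddings.nearest("lamb")` picks "beef" over "rosemary", because the two meat vectors point in nearly the same direction.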

Sounds cool, right? Now of course, the tricky bit is how can you convert a word to a vector of numbers in such a way that it encapsulates the details behind this relationship? And how can we do it without painstaking manual work and trying to somehow indicate semantic relationships and meaning in the words?


Unsupervised Learning

Word2Vec relies on neural networks and trains on a large, un-labelled piece of text in a technique known as "unsupervised" learning.

Contrary to the last neural network I discussed, which was a "supervised" exercise (i.e. for every input record we had the expected output/answer), Word2Vec uses a completely "unsupervised" approach - in other words, the neural network simply takes a massive block of text with no markup or labels (broken into sentences or lines usually) and then uses that to train itself.

This kind of unsupervised learning can seem a little unbelievable at first. The idea that a network could train itself without even knowing the "answers" seemed a little strange to me the first time I heard the concept, especially as a fundamental requirement for a NN to converge on an optimum solution is a "cost-function" (i.e. something we can use after each feed-forward step to tell us how right we are, and whether our NN is heading in the right direction).

But really, if we think back to the literal biological comparison with the brain, as people we learn through this unsupervised approach all the time - it's basically trial-and-error.


It's child's play

Imagine a toddler attempting to learn to use a smartphone or tablet: they likely don't get shown explicitly to press an icon, or to swipe to unlock, but they might try combinations of power buttons, volume controls and swiping and see what happens (and whether it does what they are ultimately trying to do). They get feedback from the device - not direct feedback about what the correct gesture is, or how wrong they were, just the feedback that it doesn't do what they want. If you have ever lived with a toddler who has got to grips with touchscreens, you may have noticed that when they then experience a TV or laptop, they instinctively attempt to touch or swipe the things on the screen that they want (in NN terms this would be known as "overfitting" - they have trained on too specific a set of data, so are poor at generalising - luckily, the introduction of a non-touch screen such as a TV expands their training set and they continue to improve their NN, getting better at generalising!).

So, this is basically how Word2Vec works. Which is pretty amazing if you think about it (well, I think it's neat).


Word2Vec approaches

So how does this apply to Word2Vec? Well, just like a smartphone gives implicit, indirect feedback to a toddler, so the input data can provide feedback to itself. There are broadly two techniques when training the network:

Continuous Bag of Words (CBOW)

So, our NN has a large body of text broken up into sentences/lines - and just like in our last NN example, we take the first row from the training set, but we don't just take the whole sentence to push into the NN (after all, the sentence will be variable length, which would confuse our input neurons), instead we take a set number of words - referred to as the "window size", let's say 5, and feed those into the network. In this approach, the goal is for the NN to try and correctly guess the middle word in that window - that is, given a phrase of 5 words, the NN attempts to guess the word at position 3.

[It was ___ of those] days, not much to do

So it's unsupervised learning, as we haven't had to go through any data and label things, or do any additional pre-processing - we can simply feed in any large body of text and it can just try to guess the words given their context.

Skip-gram

The Skip-gram approach is similar, but the inverse - that is, given the word at position n, it attempts to guess the words at position n-2, n-1, n+1, n+2.

[__ ___ one __ _____] days, not much to do

The network is trying to work out which word(s) are missing, and just looks to the data itself to see if it can guess it correctly.
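The two approaches differ only in which side of the window is the input and which is the thing being guessed. A small sketch of how the training pairs could be generated from a sentence is below. It follows the description above (a fixed window of 5 with the middle word as the focus); real Word2Vec implementations define and sample the window somewhat differently, so treat this as an illustration of the idea rather than how any particular library does it.

```scala
object TrainingWindows {
  // Very crude tokeniser: lower-case and split on anything that is not a letter.
  def tokens(sentence: String): Vector[String] =
    sentence.toLowerCase.split("[^a-z]+").filter(_.nonEmpty).toVector

  // CBOW: for each full window, the surrounding words are the input
  // and the middle word is the target to guess.
  def cbow(words: Vector[String], window: Int = 5): Vector[(Vector[String], String)] = {
    val mid = window / 2
    words.sliding(window).filter(_.size == window).map { w =>
      (w.patch(mid, Nil, 1), w(mid)) // patch removes the middle word from the context
    }.toVector
  }

  // Skip-gram: the inverse, the middle word is the input
  // and each surrounding word is a target to guess.
  def skipGram(words: Vector[String], window: Int = 5): Vector[(String, String)] =
    cbow(words, window).flatMap { case (context, target) => context.map(target -> _) }
}
```

For the example sentence, the first CBOW pair is the context (it, was, of, those) with target "one", and the first skip-gram pairs are (one, it), (one, was), (one, of), (one, those). No labels were needed at any point: the text supplies its own "answers".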


Word2Vec with DeepLearning4J

One popular deep-learning and Word2Vec implementation on the JVM is DeepLearning4J. It is pretty simple to use, which makes it easy to get a feel for what is going on, and it is pretty well documented (along with some good high-level overviews of some core topics). You can get up and running with the library and some example datasets pretty quickly by following their guide. Their NN setup is equally simple and worth playing with; their MNIST hello-world tutorial lets you get going with that dataset pretty quickly.

Food2Vec

A little while ago, I wrote a web crawler for the BBC food recipe archive, so I happened to have several thousand recipes sitting around and thought it might be fun to feed those recipes into Word2Vec to see if it could give any interesting results or if it was any good at recommending food pairings based on the semantic features the Word2Vec NN extracts from the data.

The first thing I tried was just using the ingredient list as a sentence - hoping that it would be better for extracting the relationship between ingredients, with each complete list of ingredients being input as a sentence.  My hope was that if I queried the trained model for X is to Beef, as Rosemary is to Lamb, I would start to get some interesting results - or at least be able to enter an ingredient and get similar ingredients to help identify possible substitutions.

As you can see, it has managed to extract some meaning from the data - for both pork and lamb, the nearest words do seem to be related to the target word, but not so much that could really be useful. Although this in itself is pretty exciting - it has taken an un-labelled body of text and has been able to learn some pretty accurate relationships between words.

Actually, on reflection, a list of ingredients isn't actually that great an input, as it isn't a natural structure and there is no natural ordering of the words - a lot of meaning is captured in the phrases rather than just lists of words.

So next up, I used the instructions for the recipes - each step in the recipe became a sentence for input, and minimal cleanup was needed. However, even with some basic tweaking (it's quite possible that if I had played more with the Word2Vec configuration I could have got improved results), the results weren't really that much better, and for the same lamb & pork search this was the output:

Again, it's still impressive to see that some meaning has been found from these words, but is it better than the raw ingredient list? I think not - the pork one seems wrong, as it seems to have very much aligned pork with poultry (although maybe that is some meaningful insight that conventional wisdom just hasn't taught us yet!?)

Arithmetic

Whilst this is pretty cool, there is further fun that can be had - in the form of simple arithmetic. A simple, often quoted example, is the case of countries and their capital cities - well trained Word2Vec models have countries and their capital cities equal distances apart:

(graph taken from DeepLearning4J Word2Vec intro)

So could we extract similar relationships between food stuffs?  The short answer, with the models trained so far, was kind of..

Word2Vec supports the idea of positive and negative matches when looking for nearest words, which allows you to find these kinds of relationships. So what we are looking for is something like "X is to lamb, as thigh is to chicken" (i.e. hopefully this should find a part of the lamb), and hopefully we can use this to extract further information about ingredient relationships that could be useful in thinking about food.
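Under the hood, the positive/negative query is just vector arithmetic: add the positive words, subtract the negative ones, and look for the nearest remaining word. The sketch below shows that with hand-made 3-dimension vectors, where the axes are loosely "chicken-ness", "lamb-ness" and "leg-cut-ness". The vectors are invented so that the example works; a trained model would be far noisier.

```scala
object ToyAnalogy {
  type Vec = Vector[Double]

  // Hand-made vectors (chicken-ness, lamb-ness, leg-cut-ness); not trained values.
  val words: Map[String, Vec] = Map(
    "chicken" -> Vector(1.0, 0.0, 0.0),
    "lamb"    -> Vector(0.0, 1.0, 0.0),
    "thigh"   -> Vector(1.0, 0.0, 1.0),
    "shank"   -> Vector(0.0, 1.0, 1.0),
    "breast"  -> Vector(1.0, 0.0, 0.5)
  )

  def add(a: Vec, b: Vec): Vec = a.zip(b).map { case (x, y) => x + y }
  def sub(a: Vec, b: Vec): Vec = a.zip(b).map { case (x, y) => x - y }
  def dot(a: Vec, b: Vec): Double = a.zip(b).map { case (x, y) => x * y }.sum
  def cosine(a: Vec, b: Vec): Double =
    dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

  // Sum the positive vectors, subtract the negative ones, then return the
  // nearest word that was not part of the query. Both lists must be non-empty.
  def analogy(positive: Seq[String], negative: Seq[String]): String = {
    val target = sub(positive.map(words).reduce(add), negative.map(words).reduce(add))
    val exclude = (positive ++ negative).toSet
    words
      .filter { case (w, _) => !exclude(w) }
      .maxBy { case (_, v) => cosine(target, v) }
      ._1
  }
}
```

So "X is to lamb, as thigh is to chicken" becomes `ToyAnalogy.analogy(positive = Seq("thigh", "lamb"), negative = Seq("chicken"))`, i.e. thigh - chicken + lamb, which with these toy vectors lands exactly on "shank".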

So, I ran that arithmetic against my two models.
The instructions based model returned the following output:

Which is a pretty good effort - I think if I had to name a lamb equivalent of chicken thigh, a lamb shank is probably what I would have gone for (top of the leg, both pieces of slow twitch muscle and both the more game-y, flavourful pieces of the animal - I will stop as we are getting into food-nerd territory).

I also ran the same query on the ingredients based set (which remember, ran better on the basic nearest words test):

Which, interestingly, doesn't seem as good. It has the shin, which isn't too bad in as far as it's the leg of the animal, but not quite as good a match as the previous.


Let us play

Once you have the input data, Word2Vec is super easy to get up and running. As always, the code is on GitHub if you want to see the build stuff (I did have to fudge some dependencies and exclude some stuff to get it running on Ubuntu - you may get errors relating to javacpp or jnind4j not being available - but the build file has the required workarounds in to get that running), but the interesting bit is as follows:

If we run through what we are setting up here:

  1. Stop words - these are words we know we want to ignore -  I originally ruled these out as I didn't want measurements of ingredients to take too much meaning. 
  2. Line iterator and tokenizer - these are just core DL4J classes that will take care of processing the text line by line, word by word. This makes things much easier for us, so we don't have to worry about that stuff
  3. Min word frequency - this is the threshold for words to be interesting to us - if a word appears fewer than this number of times in the text then we don't include the mapping (as we aren't confident we have a strong enough signal for it)
  4. Iterations - how many training cycles we are going to loop for
  5. Layer size - this is the size of the vector that we will produce for each word - in this case we are saying we want to map each word to a 300 dimension vector. You can consider each dimension of the vector a "feature" of the word that is being learnt; this is a part of the network that will really need to be tuned to each specific problem
  6. Seed - this is just used to "seed" the random numbers used in the network setup, setting this helps us get more repeatable results
  7. Window size - this is the number of words to use as input to our NN each time - relates to the CBOW/Skip-gram approaches described above.
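To give a feel for what the stop words, min word frequency and window size settings do to the text before any learning happens, here is a rough sketch in plain Scala. This is not what DL4J does internally: the names (WindowSketch, tokenize, vocabulary, contextPairs) are hypothetical, and I am assuming the window counts words on each side of the target word.

```scala
object WindowSketch {
  // Lower-case a line, split on whitespace and drop the stop words
  def tokenize(line: String, stopWords: Set[String]): Vector[String] =
    line.toLowerCase.split("\\s+").toVector.filter(w => w.nonEmpty && !stopWords(w))

  // Keep only the words that appear at least minWordFrequency times overall
  def vocabulary(lines: Seq[Vector[String]], minWordFrequency: Int): Set[String] =
    lines.flatten
      .groupBy(identity)
      .collect { case (word, occurrences) if occurrences.size >= minWordFrequency => word }
      .toSet

  // Pair each word with every neighbour up to windowSize positions away
  def contextPairs(tokens: Vector[String], windowSize: Int): Vector[(String, String)] =
    for {
      i <- tokens.indices.toVector
      j <- (math.max(0, i - windowSize) to math.min(tokens.length - 1, i + windowSize)).toVector
      if i != j
    } yield (tokens(i), tokens(j))
}
```

For example, tokenising "Roast the lamb shank with 2 tbsp mint" with the stop words "the", "with" and "tbsp" leaves roast, lamb, shank, 2 and mint, and a window of 1 would then pair lamb with roast and shank only. These (word, neighbour) pairs are the kind of thing the network is trained on.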

And that's all you need to really get your first Word2Vec model up and running! So find some interesting data, load it in and start seeing what interesting stuff you can find.

So go have fun - try and find some interesting data sets of text stuff you can feed in and what you can work out about the relationships - and feel free to comment here with anything interesting you find.

Machine (re)learning: Neural Networks from scratch

Having recently changed roles, I am now in the enviable position of starting to do some work with machine learning (ML) tools. My undergraduate degree was actually in Artificial Intelligence, but that was over a decade ago, which is a long time in computer science in general, let alone the field of Machine Learning and AI which has progressed massively in the last few years.

So needless to say, coming back to it there has been a lot to learn. In the future I will write a bit more about some of the available tools and libraries that exist these days (both expanding on the traditional AI Stanford libraries I have mentioned previously with my tweet sentiment analysis, plus newer frameworks that cover "deep learning").

Anyway, inspired by this post, I thought it would be a fun Sunday night refresher to write my own neural network. The first, and last, time that I wrote a neural network was for my final year dissertation (and that code is long gone), so I was writing from first principles. The rest of this post will be a very straightforward introduction to the ideas and the code for a basic single layer neural network with a simple sigmoid activation function. I am training and testing on a very simple labeled data set and with the current configuration it scores 90% on the unseen classification tests. I am by no means any kind of expert in this field, and there are plenty of resources and papers written by people who are inventing this stuff as we speak, but this will be more the musings of someone re-learning this stuff.

As always, all the code is on GitHub (and, as per my change in roles, this time it is all written in Scala).

A brief overview

The basic idea is that a Neural Network (NN) attempts to mimic the parallel architecture of the human brain. The human brain is a massive network of billions of simple neural cells that are interconnected, and given a stimulus they each either "fire" or don't, and it's these firing neural cells and synapses that enable us to learn (excuse my crude explanation, I'm clearly not a biologist..).

So that is what we are trying to build: A connected network of neurons that given some stimuli either "fire" or don't.

A Neuron

Ok, so this seems like a sensible place to start, right? We all know what a network is, so what are these nodes in our network that we are connecting? These are the decision points, and in themselves are incredibly simple. They are basically a function that takes n input values, multiplies each one by a weight (one per input), sums the results, adds a bias and then runs the total through an activation function (think of this as our fire/don't-fire function):



In terms of our code, this is pretty simple (don't worry about the sigmoid derivative values we are setting, we are just doing this to save time later):

As you can see, the neuron holds the state about the weights of the different inputs (a NN is normally fixed in terms of number of neurons, so once initialised at the start we know that we will get the same number of inputs).
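As a rough illustration of the idea (this is a sketch, not the code from the repository), a neuron along these lines could look like the following. The ActivationFunction name follows the post; the Step activation, the Neuron constructor and the output method are assumptions made up for the example.

```scala
object NeuronSketch {
  // Contract for the fire / don't-fire function
  trait ActivationFunction {
    def apply(x: Double): Double
  }

  // Hypothetical hard threshold: the most literal fire / don't-fire function
  object Step extends ActivationFunction {
    def apply(x: Double): Double = if (x >= 0.0) 1.0 else 0.0
  }

  // The neuron holds one weight per input, plus a bias
  class Neuron(activation: ActivationFunction, val weights: Vector[Double], val bias: Double) {
    def output(inputs: Vector[Double]): Double = {
      require(inputs.length == weights.length, "expected one input per weight")
      // Multiply each input by its weight, sum, add the bias, then activate
      val weightedSum = inputs.zip(weights).map { case (i, w) => i * w }.sum + bias
      activation(weightedSum)
    }
  }
}
```

With weights of 0.5 and -1.0 and a bias of 0.1, the inputs (1.0, 0.2) give a weighted sum of 0.4, so the step neuron fires, whereas the inputs (0.0, 1.0) give -0.9 and it doesn't.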

You may have also noticed the first line of the class, where we require an ActivationFunction to be applied - as I mentioned, this is our final processing of the output. In this example I am just using the Sigmoid function:

As you can see, it's a pretty simple function. Much like the brain, the neurons are very simple processors, and it's only the combination of them that makes the brain so powerful (or deep learning, for that matter).
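For reference, the standard sigmoid and its derivative can be sketched in a couple of lines of Scala. The object and method names here are my own for the example, and need not match the repository.

```scala
object SigmoidSketch {
  // Squashes any real number into the range (0, 1)
  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  // Derivative of the sigmoid, written in terms of its output y = sigmoid(x)
  def sigmoidDerivative(y: Double): Double = y * (1.0 - y)
}
```

Large positive inputs are pushed towards 1 ("fire"), large negative inputs towards 0 ("don't fire"), and an input of 0 lands exactly in the middle at 0.5. The derivative being expressible from the output alone is what makes it cheap to store for use later during training.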

The network

Ok, so a simple "shallow" NN has an input layer, a single hidden layer and an output layer (deep NNs will have multiple hidden layers). The network is fully connected between layers - that is to say, each node is connected to every node in the next layer up, and that's it - no connections skip a layer and no connections are missed.



The exact setup of the network is largely problem dependent, but there are some observations:

  1. The input layer has to correlate to your inputs: let's say you have a dataset that contains house price data, and the input values are number of rooms and square footage, then you would have to have two input neurons.

  2. Similarly, if you were using the NN for a classification problem and you know you have a fixed number of classifications, the output neurons would correlate to that. For example, if you were using the NN to recognise handwritten digits (0-9) then you would likely have 10 output neurons, one for each of the possible outputs.

The number of hidden neurons, or the number of hidden layers, is a lot more problem dependent and is something that needs to be tuned per problem.

If you have ever worked with graph type structures in code before, then setting up a simple network of these neurons is also relatively straightforward given the uniform structure of the network.

The weights

Oh right, yes. So all these neurons are connected, but it's a weighted graph, in CS terms. That is, the connection between each node is assigned a weight - if we take the house price dataset as a simple example, after looking at the data we might determine that the number of rooms is a more significant factor in the end result, so that connection should have a greater weighting than the other input. Now that is not exactly what happens, but in simple terms it's a good way of thinking about it.

Given that the end goal is for the NN to learn from the data, it really doesn't matter too much what we initialise the weights for all the connections to, so at start-up we can just assign random values. Just think of our NN as a newborn baby - it has no idea how important different inputs are, or how important the different neural cells that fire in response to the stimuli are - everything is just firing off all over the place as they slowly start to learn!
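As a sketch of what "fully connected with random weights" might look like in code, the following builds one weight per connection between neighbouring layers. The names (NetworkSketch, Layer, randomLayer, randomNetwork) are hypothetical, and drawing the starting weights from the range -1 to 1 is an assumption for the example, not necessarily what the repository does.

```scala
import scala.util.Random

object NetworkSketch {
  // weights(j)(i) is the weight on the connection from input i to neuron j
  final case class Layer(weights: Vector[Vector[Double]], biases: Vector[Double])

  // Every neuron gets one random weight per input, plus a random bias
  def randomLayer(inputs: Int, neurons: Int, rng: Random): Layer =
    Layer(
      Vector.fill(neurons)(Vector.fill(inputs)(rng.nextDouble() * 2 - 1)),
      Vector.fill(neurons)(rng.nextDouble() * 2 - 1)
    )

  // layerSizes lists the neurons per layer, e.g. Seq(2, 3, 2) for
  // two inputs, three hidden neurons and two outputs
  def randomNetwork(layerSizes: Seq[Int], rng: Random): Vector[Layer] = {
    require(layerSizes.length >= 2, "need at least an input and an output layer")
    layerSizes.sliding(2).map { case Seq(in, out) => randomLayer(in, out, rng) }.toVector
  }
}
```

Note that the input layer has no weights of its own: a network described as Seq(2, 3, 2) ends up with two weighted layers, the first with 3 x 2 weights and the second with 2 x 3. Passing a seeded Random makes the "newborn" network repeatable between runs.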

So what next?

Ok, so we have our super-simple neurons, that just mimic a single brain cell and we have them connected in a structured but randomly weighted fashion, so how does the network learn? Well, this is the interesting bit..

When it comes to training the network we need a pretty large dataset to allow it to learn enough to start generalising - but at the same time, we don't want to train it on all the data, as we want to hold some back to test it at the end, to see just how smart it really is.
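Holding some data back can be as simple as shuffling and splitting. A minimal sketch, where the trainTestSplit name and the idea of passing the fraction as a parameter are my own assumptions for the example:

```scala
import scala.util.Random

object SplitSketch {
  // Shuffle the records, then cut them into a training part and a held-back test part
  def trainTestSplit[A](data: Seq[A], trainFraction: Double, rng: Random): (Vector[A], Vector[A]) = {
    require(trainFraction > 0.0 && trainFraction < 1.0, "fraction must be between 0 and 1")
    val shuffled = rng.shuffle(data.toVector)
    shuffled.splitAt((data.length * trainFraction).toInt)
  }
}
```

Calling trainTestSplit(records, 0.8, new Random(1)) would train on 80% of the records and keep the remaining 20% unseen for the final test.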

In AI terms, we are going to be performing "supervised" learning - this just means that we will train it with a dataset where we know the correct answer, so we can make adjustments based on how well (or badly) the network is doing - this is different to "unsupervised" learning, where we have lots of data, but we don't have the right answer for each data point.

Training: Feed forward


The first step is the "feed-forward" step - this is where we grab the first record from our training data set and feed the inputs into our randomly initialised NN. That input is fed through all of the neurons in the NN until we get to the output layer and we have the network's attempt at the answer. As the weighting is all random, this is going to be way off (imagine a toddler's guess the first time they ever see a smartphone).

As you can see, the code is really simple, just iterate through the layers calculating the output for each neuron. Now, as this is supervised learning we also have the expected output for the dataset, and this means we can work out how far off the network is and attempt to adjust the weights in the network so that next time it performs a bit better.
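As a rough sketch of the feed-forward step (an illustration under my own assumptions, not the repository code), each layer's outputs simply become the next layer's inputs. The Layer shape and the function names are hypothetical.

```scala
object FeedForwardSketch {
  // weights(j)(i) is the weight on the connection from input i to neuron j
  final case class Layer(weights: Vector[Vector[Double]], biases: Vector[Double])

  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  // Output of every neuron in one layer for the given inputs
  def layerOutput(layer: Layer, inputs: Vector[Double]): Vector[Double] =
    layer.weights.zip(layer.biases).map { case (ws, b) =>
      sigmoid(ws.zip(inputs).map { case (w, i) => w * i }.sum + b)
    }

  // Iterate through the layers, feeding each layer's output into the next
  def feedForward(network: Vector[Layer], inputs: Vector[Double]): Vector[Double] =
    network.foldLeft(inputs)((activations, layer) => layerOutput(layer, activations))
}
```

The fold is the whole algorithm: start with the raw inputs, and by the time the last layer has been applied, what is left is the network's guess.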

Training: Back propagation

This is where the magic, and the maths, comes in.  The TL;DR overview for this is basically, we know how wrong the network was in its guess, so we attempt to update each of the connection weights based on that error, and to do so we use differentiation to work out the direction to go to minimise the error.

This took a while of me staring at equations and racking my brain trying to remember my A-level maths stuff to get this, and even so, I wouldn't like to have to go back and attempt to explain this to my old maths teacher, but I'm happy I have enough of a grasp of what is going on, having now coded it up.

In a very broad, rough overview, here is what we are going to do now:

  1. Calculate the squared error of the outputs. That is a fairly simple equation of
    1/2 * (target - output)^2

  2. From there, we work out the derivative of this function. A lot of the values we will work with from now on use differentiation and the derivatives of the result - this takes a little bit of calculus, but the basic reason is relatively straightforward if you think about what differentiation is for.

  3. We work back through the network, calculating the error for every neuron by combining the error, the derivatives and the weighting (the weighting determines how much a particular connection contributed to the error, so it also needs to be considered when correcting the weights).
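The three steps above can be sketched as a handful of small functions. This assumes the sigmoid activation used in this post and the standard chain-rule form of back propagation; the function names are hypothetical and not taken from the repository.

```scala
object ErrorSketch {
  // Step 1: squared error for a single output
  def squaredError(target: Double, output: Double): Double =
    0.5 * math.pow(target - output, 2)

  // Step 2: derivative of 1/2 * (target - output)^2 with respect to the output
  def errorDerivative(target: Double, output: Double): Double = output - target

  // Derivative of the sigmoid, written in terms of the neuron's output
  def sigmoidDerivative(output: Double): Double = output * (1.0 - output)

  // Error signal ("delta") for an output neuron: error slope times activation slope
  def outputDelta(target: Double, output: Double): Double =
    errorDerivative(target, output) * sigmoidDerivative(output)

  // Step 3: a hidden neuron's delta combines the deltas of the neurons it feeds,
  // each scaled by the weight of the connection that carried its output forward.
  // downstream holds (connection weight, delta of the neuron it connects to).
  def hiddenDelta(output: Double, downstream: Seq[(Double, Double)]): Double =
    downstream.map { case (weight, delta) => weight * delta }.sum * sigmoidDerivative(output)
}
```

For a target of 1.0 and an output of 0.5, the squared error is 0.125 and the output delta is -0.125: the negative sign says the output was too low, so the weights feeding it need to go up.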

Once we have adjusted the weight for each connection based on the error, we start again and run the next training data record. Given the variation in the dataset, the next training record might be quite different to the previous record trained on, so the adjustment might be quite different too. After lots (maybe millions) of iterations over the data, incrementally tweaking and adjusting the weights for the different cases, hopefully it starts to converge on something like a reasonable performance.
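To show the record-by-record adjustment in the smallest possible setting, here is a sketch that trains a single sigmoid neuron on a toy dataset (logical OR). Everything here is an assumption for illustration: the names, the learning rate parameter and the toy data are mine, and a real network repeats the same idea across every layer.

```scala
object TrainingSketch {
  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  final case class Neuron(weights: Vector[Double], bias: Double) {
    def output(inputs: Vector[Double]): Double =
      sigmoid(weights.zip(inputs).map { case (w, i) => w * i }.sum + bias)
  }

  // One training record: feed forward, measure the error signal,
  // then nudge each weight a small step against its share of the error
  def trainOne(n: Neuron, inputs: Vector[Double], target: Double, learningRate: Double): Neuron = {
    val out = n.output(inputs)
    val delta = (out - target) * out * (1.0 - out)
    Neuron(
      n.weights.zip(inputs).map { case (w, i) => w - learningRate * delta * i },
      n.bias - learningRate * delta
    )
  }

  // Sum of squared errors over the whole dataset
  def totalError(n: Neuron, data: Seq[(Vector[Double], Double)]): Double =
    data.map { case (in, t) => 0.5 * math.pow(t - n.output(in), 2) }.sum

  // Run through every record, over and over, carrying the adjusted neuron forward
  def train(start: Neuron, data: Seq[(Vector[Double], Double)], epochs: Int, learningRate: Double): Neuron =
    (1 to epochs).foldLeft(start) { (n, _) =>
      data.foldLeft(n) { case (acc, (in, t)) => trainOne(acc, in, t, learningRate) }
    }

  // Toy dataset: logical OR of two inputs
  val orData: Seq[(Vector[Double], Double)] = Seq(
    (Vector(0.0, 0.0), 0.0),
    (Vector(0.0, 1.0), 1.0),
    (Vector(1.0, 0.0), 1.0),
    (Vector(1.0, 1.0), 1.0)
  )
}
```

Starting from Neuron(Vector(0.0, 0.0), 0.0), every guess is 0.5; after a few thousand passes with train(start, orData, 5000, 0.5) the total error is far lower, because each record has pulled the weights a little in its own direction and the pulls have settled into a compromise.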

Maths fun: Why the derivative of the errors?

So, why is calculating the derivative relevant in adjusting the errors?

If you imagine the network's error as a function of all the different weights, it's pretty complicated, but for the sake of easier visualisation we can reduce this to a 3-d graph (e.g. we have just two weights to adjust, plotted along the two horizontal axes, with the resulting error as the height) - now imagine our function plots a graph something like this:


(taken from the Wikipedia page on partial derivatives)

The derivative of the function allows us to work out the direction of the slope in the graph from a given point, so with our current weights (co-ordinates) we can use the derivative to work out the direction in which we need to adjust these weights: downhill, towards a lower error.
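A small sketch of "following the slope downhill", using a simple bowl-shaped function of two weights that I have picked purely for illustration (a real network's error surface is far messier). The names and the learning rate are assumptions for the example.

```scala
object GradientDescentSketch {
  // Illustrative bowl-shaped "error" as a function of two weights
  def error(x: Double, y: Double): Double = x * x + x * y + y * y

  // Partial derivatives of the error with respect to each weight
  def gradient(x: Double, y: Double): (Double, Double) = (2 * x + y, x + 2 * y)

  // Move a small step in the opposite direction to the slope
  def step(point: (Double, Double), learningRate: Double): (Double, Double) = {
    val (x, y) = point
    val (dx, dy) = gradient(x, y)
    (x - learningRate * dx, y - learningRate * dy)
  }

  // Repeat the step a fixed number of times
  def descend(start: (Double, Double), learningRate: Double, steps: Int): (Double, Double) =
    (1 to steps).foldLeft(start)((p, _) => step(p, learningRate))
}
```

Starting at (3.0, -2.0) the error is 7.0; a single step with a learning rate of 0.1 moves to (2.6, -1.9), where the error is lower, and repeating the step keeps rolling the point towards the bottom of the bowl at (0, 0).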

Conclusion

The whole NN in Scala is currently on GitHub, and there isn't much more code than I have included here. I have found it pretty helpful to code it up from scratch and actually have to think about it - and it has been fun coding it up and seeing it train up to getting 90% accuracy on the unseen data set I was using (just a simple two-in two-out dataset, so not that impressive).

Next up I will see how the network performs against the MNIST dataset (the classification of handwritten digits, which is like the Hello-World benchmark for machine learning).

Oh, and if you are interested as to why the images look like a dodgy photocopy, they are photos of original diagrams that I included in my final year dissertation in university, just for old times' sake!

Google's AI Challenge

It is that time of year again when Google is sponsoring an AI challenge - it's a standard affair: you have to control a set of ants, collect food and eventually defeat other ant hills, but it's good fun and always a good opportunity to play around with programming in a more competitive environment.

Full details are here; you can download the starter pack and get started in lots of programming languages, but unsurprisingly I have gone with Java for now. Once you have the pack downloaded, there are some recommended steps to get your ants doing the basics well. Although you will obviously want to go beyond this minimum set, it's a good chance to get your head around the object model and the APIs that are provided as standard.

I have just started this today and have completed the first 4 steps of the walkthrough, so not even started thinking about personal tactics and strategy.. so expect some posts to follow with those.

In the meantime, I am, as always, keeping my code on GitHub, so feel free to have a poke around how I have implemented the various steps in my code (only the first step of the tutorial is implemented in Java at the moment, the others are all currently in Python, so it does require the slightest bit of imagination).