
JAXLondon 2017: Agile Machine Learning [VIDEO]

Last October a colleague and I gave a talk at the JAXLondon Conference about Machine Learning in an agile, commercial environment (I also gave the talk again in November at the W-JAX Conference in Munich).




The video of the talk is now available. The first half (and the end section) is mostly the softer stuff, where I talk about lessons learnt from doing ML research in a commercial environment; the middle section is my colleague, Sumanas, explaining how Word2Vec works, with some demos of it in an interesting application!




In related conference speaking news, I will be in Berlin next month for the API Conference to talk about Product Management and API Design: https://apiconference.net/api-design-documentation/your-api-as-a-product-thinking-like-a-product-manager/



We need to talk about AI

Ethical and Regulatory questions facing AI





Regardless of our area of expertise, most of us are probably already aware of the momentum behind Artificial Intelligence (AI). Between self-driving cars, home assistants (Alexa, Google Home, et al.) and the growing capabilities of our mobile devices, there is no escaping the ever-looming presence of AI in our lives.

Furthermore, it seems unlikely that this will slow down any time soon. A recent Narrative Science study found that AI adoption grew by 60% in the last year, with 61% of organisations reporting that they had implemented AI within their business, and a Gartner report predicted that by 2020, 85% of customer interactions will be managed without human intervention.

But despite this growth, there is still a question mark over whether, and if so how, the field should be regulated. Having been brought up on decades of sci-fi about AI going rogue and robots enslaving the human race, we seem to hold both a fear of this possible future and a scepticism that such fears are only the stuff of movies. Elon Musk has famously warned of the future risks of AI: “I think we should be very careful about artificial intelligence. If I had to guess at what our biggest existential threat is, it’s probably that”, whilst others, including Mark Zuckerberg, have dismissed such doomsday scenarios as irresponsible.

So what's the big deal? AI already permeates so many aspects of life and business, but consider for a moment that these technologies could be used to control autonomous cars on public roads, determine people’s credit scores or suitability for a job, detect illness, or even inform policing and judicial decisions. It is pretty clear that we should have a good understanding of these technologies, and clear systems of accountability and control in place. In all these examples, getting a decision wrong has the potential to ruin lives, yet there is still limited regulation, control or even understanding of the algorithms, the data and how they are used.

A common analogy is with other heavily regulated industries: big pharma companies can’t release drugs without thorough testing and approval, yet several big tech companies have already started testing autonomous vehicles on public roads with limited regulatory controls. (That’s not to say they have had a completely free pass; the level of regulation varies by region. Arizona has long promoted itself as an AI-friendly state to attract business from big tech, making it as easy as possible for companies to test self-driving cars with minimal regulatory friction, and it recently saw the first fatality involving a self-driving car.)

In its 2017 report, the AI Now Institute recommended that AI be banned outright from high-risk areas such as criminal justice, healthcare, welfare and education, with further measures for other domains. Given the potential impact of errors in these domains, that seems a fairly sensible starting point.



Uncertainty and the unknown


One especially troubling aspect is the lack of understanding of both the data and the underlying technology. Perhaps that is to be expected: computers are being trained on millions of data points, to the point where they can outperform humans at their tasks, so it is no great surprise that both the inner workings and the end results can be beyond easy comprehension.

This problem has been demonstrated by several high-profile mishaps at large tech companies, showing that even companies with a wealth of resources and technical expertise in the domain can be caught out. Microsoft’s AI chatbot Tay, for example, quickly started posting racist content when released into the wild; clearly Microsoft had neither intended nor envisaged that result. Similarly, when Google Translate was found to exhibit gender bias, pairing “he” with “hardworking” and “she” with “lazy”, it clearly wasn’t intentional or foreseen behaviour, but it eventually revealed itself with wider usage.



Understanding where bias in AI comes from


To get a better understanding of where these biases and blind spots come from, let’s take a look at how AI learns. Broadly speaking, there are three primary approaches to training AI: supervised, unsupervised and reinforcement learning.

Unsupervised learning is where the AI is fed very large amounts of raw data - for example an entire corpus of fictional texts - and it is left to work out patterns or groupings. That is, it doesn’t know a right or wrong answer, but can identify related things from the dataset and group them together (for example, AI reading popular fiction might group together terms such as “batman” and “wonder woman”, but it would have no knowledge of what these terms actually mean).
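
To make the idea of an algorithm "left to work out groupings" a bit more concrete, here is a minimal k-means clustering sketch in Scala. K-means is just one simple unsupervised technique, chosen for illustration, and the 2-D points, starting centroids and iteration count are all made up: the algorithm only ever sees unlabelled points and groups them by proximity.

```scala
object KMeansSketch {
  type Point = (Double, Double)

  // Squared Euclidean distance between two points
  def dist2(a: Point, b: Point): Double = {
    val dx = a._1 - b._1
    val dy = a._2 - b._2
    dx * dx + dy * dy
  }

  // Index of the centroid closest to p
  def nearest(p: Point, centroids: Vector[Point]): Int =
    centroids.indices.minBy(i => dist2(p, centroids(i)))

  // Lloyd's algorithm: assign each point to its nearest centroid, move each
  // centroid to the mean of its points, and repeat. No labels are ever used.
  def kMeans(points: Vector[Point], initial: Vector[Point], iterations: Int): Vector[Int] = {
    var centroids = initial
    for (_ <- 1 to iterations) {
      val assignment = points.map(p => nearest(p, centroids))
      centroids = centroids.indices.map { i =>
        val members = points.zip(assignment).collect { case (p, a) if a == i => p }
        if (members.isEmpty) centroids(i) // leave an empty cluster where it was
        else (members.map(_._1).sum / members.size, members.map(_._2).sum / members.size)
      }.toVector
    }
    points.map(p => nearest(p, centroids))
  }
}
```

Given a handful of points around (1, 1) and (8, 8), with starting centroids at (0, 0) and (10, 10), this puts the first group in cluster 0 and the second in cluster 1 without ever being told what the groups are.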

Supervised learning is where the AI is fed very large amounts of marked-up data: for each input, it is also given the expected output. For example, a large set of photos (say, from Google Photos) pre-tagged with descriptions of their contents could be used to train an AI to identify what is in a photo.
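
By contrast with the unsupervised case, every supervised training example carries its expected output. A minimal sketch of that idea, using a 1-nearest-neighbour classifier rather than anything you would actually use for photo tagging; the two-number feature vectors and the "cat"/"dog" tags are invented for illustration:

```scala
object SupervisedSketch {
  // A labelled example: input features plus the expected output (the "tag")
  final case class Example(features: Vector[Double], label: String)

  // Euclidean distance between two feature vectors
  private def distance(a: Vector[Double], b: Vector[Double]): Double =
    math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

  // 1-nearest-neighbour: predict the label of the closest training example
  def predict(training: Seq[Example], features: Vector[Double]): String =
    training.minBy(ex => distance(ex.features, features)).label
}
```

The model can only ever answer with labels it was given during training, which is exactly why the quality of those labels matters so much.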

Reinforcement learning is similar to supervised learning insofar as the algorithm gets information about whether or not it is performing well, but that information comes in the form of a feedback loop and works more like trial-and-error learning (there might be a general fitness or reward function that the algorithm uses to determine whether its response to a given input was successful, and to adjust its response for the next cycle). A well-known example is AlphaGo/AlphaZero, where an algorithm learns to play a game like Go or chess by trial and error and gets feedback on its moves from the game itself.
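
A full AlphaZero-style system is far beyond a blog snippet, but the feedback-loop idea can be sketched with a much simpler reinforcement learning setup: an epsilon-greedy multi-armed bandit. The reward values, epsilon and seed are illustrative assumptions, and the rewards are deterministic to keep things simple. The agent is never told which action is best; it only sees the reward each attempt earns and updates its estimates from that.

```scala
import scala.util.Random

object BanditSketch {
  // Epsilon-greedy trial and error: mostly pick the action that has looked
  // best so far, occasionally try a random one, and learn from the reward.
  def train(rewards: Vector[Double], steps: Int, epsilon: Double, seed: Long): Vector[Double] = {
    val rng = new Random(seed)
    val estimates = Array.fill(rewards.size)(0.0)
    val counts = Array.fill(rewards.size)(0)
    for (_ <- 1 to steps) {
      val action =
        if (rng.nextDouble() < epsilon) rng.nextInt(rewards.size) // explore
        else estimates.indices.maxBy(i => estimates(i))          // exploit
      val reward = rewards(action) // feedback from the "environment"
      counts(action) += 1
      // incremental mean of the rewards observed for this action
      estimates(action) += (reward - estimates(action)) / counts(action)
    }
    estimates.toVector
  }
}
```

With rewards of 0.2, 0.5 and 0.9, the agent should end up estimating the third action as the best, purely through trial and error.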

Both supervised and unsupervised learning require vast amounts of data to train AI accurately, which leads us to one of the primary challenges in building fair and ethical AI: sourcing the training data. AI depends on these huge datasets and becomes finely tuned to all their details and subtle underlying patterns, whether or not we are aware of them, and, as we will see, getting objective, raw datasets of sufficient size is rife with challenges.



Institutional bias


Much like Conway’s Law, which states “any organization that designs a system (defined broadly) will produce a design whose structure is a copy of the organization's communication structure”, the data we generate through our actions, conversations and interactions as a society or organisation will naturally reflect the values, beliefs and structure of that society (or organisation). There is an intrinsic and inescapable subjectivity in all big data, best described by Lisa Gitelman in her book Raw Data is an Oxymoron:

Objectivity is situated and historically specific; it comes from somewhere and is the result of ongoing changes to the conditions of inquiry, conditions that are at once material, social, and ethical

A simple example of this can be seen in criminal statistics: if a police force stops and searches a particular demographic more heavily than others, that will be reflected in the numbers, and so that cultural subjectivity influences the dataset. This subjectivity will then naturally carry over to, and likely be amplified by, the trained AI as it becomes finely tuned to the data. (An example of this was seen where software used to inform sentencing decisions relied on data with institutional bias, which resulted in racial bias in its risk assessments, strengthening the AI Now report’s case for banning AI use in these areas.)



Finding complete & representative data


Compounding this problem, AI researchers face the challenge of finding datasets that are both big enough and permitted for such use. These can be hard to come by, so researchers often make do with incomplete or skewed datasets. For example, the vast historic dataset of the popular community discussion site Reddit is publicly available; it is a rich source of natural text and conversation, and a very tempting dataset for engineers and researchers to take advantage of. However, Reddit is a very specific subset of the internet, and of the real-world population, so whilst there is undoubtedly a lot to learn from that wealth of data, any AI trained on it will inherit that community’s subjectivity.

Several reports have found that these incomplete or skewed datasets only add to the bias. The 2017 AI Now report said:

data can easily privilege socioeconomically advantaged populations, those with greater access to connected devices and online services

Which, when you think about it, is to be expected: always-connected people with mobile devices will naturally generate a lot more data than those without easy access to computers. At a very simple level, the core regular users of Reddit, for example, will likely have access to mobile devices, or at the very least to computers and the internet, which rules out large parts of the population, not to mention requiring the inclination to take part in the online community.

There are also challenges intrinsic to the way AI currently works: if a particular demographic makes up only 1% of a dataset, an AI could claim 99% accuracy whilst being wrong for every member of that 1% minority. Furthermore, we know there is a strong relationship between the amount of training data and the accuracy of AI, so even if we had a perfect representation of the population, minority groups would by definition have fewer data points to train on, and the AI would inevitably tend to perform worse for them.
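
The 99% figure is simple arithmetic, and it's easy to reproduce. A minimal sketch, assuming a hypothetical dataset of 1,000 people where 10 (1%) belong to a minority group whose true label differs, and a lazy "model" that predicts the majority label for everyone:

```scala
object AccuracyTrap {
  // Fraction of predictions that match the true labels
  def accuracy(predicted: Seq[Int], actual: Seq[Int]): Double =
    predicted.zip(actual).count { case (p, a) => p == a }.toDouble / actual.size

  // Accuracy restricted to the rows belonging to one group
  def groupAccuracy(predicted: Seq[Int], actual: Seq[Int], inGroup: Seq[Boolean]): Double = {
    val rows = predicted.zip(actual).zip(inGroup).collect { case ((p, a), true) => (p, a) }
    rows.count { case (p, a) => p == a }.toDouble / rows.size
  }
}
```

Overall accuracy comes out at 0.99 while accuracy for the minority group is 0.0, one reason it can be worth breaking metrics down by group rather than relying on a single headline number.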

Finally, let’s consider again that we have a huge, rich dataset (the ideal scenario), and we try to intentionally exclude sensitive features that might explicitly encode bias: race, gender, age, etc. There are still loads of data points that may act as an indirect proxy for these features, so even without including race, gender and age in the input data, it is easy to see how these features can get encoded in other data points such as names, location, interests and communication style. This makes it even harder to detect and prevent bias in our datasets.
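
One crude way to check whether a remaining column is acting as a proxy is to see how well it alone recovers the sensitive attribute you removed. A minimal sketch, assuming invented (occupation, gender) rows; this is a quick heuristic for illustration, not a complete fairness audit:

```scala
object ProxyCheck {
  // rows are (proxyValue, sensitiveValue) pairs.
  // For each proxy value, guess the most common sensitive value seen with it,
  // then measure how often that guess is right across all rows.
  def recoverability(rows: Seq[(String, String)]): Double = {
    val majorityByProxy: Map[String, String] =
      rows.groupBy(_._1).map { case (proxy, group) =>
        proxy -> group.groupBy(_._2).maxBy(_._2.size)._1
      }
    rows.count { case (proxy, sensitive) => majorityByProxy(proxy) == sensitive }.toDouble / rows.size
  }
}
```

On a balanced, made-up sample where builders are mostly men and midwives mostly women, occupation alone recovers gender 75% of the time against a 50% baseline, so dropping the gender column wouldn't remove the signal.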

There is no objectivity in big data.



How can we address the problem?


Some of these examples are clearer cases of existing bias that we need to address when training our AI, but a tougher challenge is how to address the more subtle biases hidden in cultural subjectivity that we might not even be aware of. We all carry biases that subconsciously shape our opinions and attitudes, and if we are not consciously aware of them, how can we ensure that developers training AI have the foresight to engineer around them?

This issue highlights one often-recommended approach to tackling the problem: placing greater emphasis on diversity in the teams building AI, both in terms of individual identities and in terms of cross-functional teams. Broadly speaking, AI is often developed by teams of engineers with limited diversity, which results in a limited range of views when thinking about the dataset and about which goals the training process optimises for. The 2017 AI Now report recommended:

stakeholders in the AI field should release data on the participation of women, minorities and other marginalised groups within AI research and development.

Aside from trying to recognise subtle bias in the data, we also need to consider that the norm itself, what we currently consider acceptable, is changing. Going back to Lisa Gitelman’s quote: “Objectivity is situated and historically specific”. If you took a dataset from even just two decades ago, it’s not hard to imagine that AI trained on it would have unacceptable biases, because societal norms and general attitudes to race, gender, identity, etc. have changed significantly since then.

As a simple example, take the motor insurance industry. For decades, insurance companies identified young male drivers as a particularly high accident risk and so charged much higher premiums for that demographic. This was a widely accepted approach, and one based in statistics: young male drivers were statistically more likely to have an accident behind the wheel. But then, in 2012, EU gender discrimination regulation came into effect that prevented insurers from using gender to set premiums, so insurers stopped using that categorisation for pricing despite the data being available. If that were an AI system, it would need to be retrained on a modified dataset, with gender removed and thought given to other data points that might also need removing (names, for example, could very easily be a broad proxy for gender). Whilst this is a simpler example, as it’s a binary change in legislation with clear requirements, there are also more gradual shifts in attitude where things get a lot fuzzier, like the changes in attitudes to race, gender and sexuality over the last thirty years.

We previously discussed the idea that even if we exclude socially salient data points, such as gender, those features can still get encoded via other proxies in the data, and this change in EU regulation and its effect on the insurance industry provides an interesting case study of exactly that phenomenon. The Guardian published an article following the EU ruling explaining that, although insurers could no longer charge more because a driver was male, male premiums had actually increased relative to female premiums since. Their explanation is that, rather than classifying on the crude data point of gender, insurers now place greater importance on a wider set of data points, and it turns out that these other data points are really just acting as encoded proxies (the article lists car size, occupation and vehicle modifications). The article also notes that MoneySupermarket released a study showing that 8 of the 10 worst occupations for drink/drug-driving incidents were in the building trade, with midwives the least likely to have a drink/drug-driving offence; the suggestion being that the building trade is predominantly male, and midwifery predominantly female.



It certainly seems to me that there are still lots of challenges in foreseeing potential problems and working out how to tackle them. A key starting point will be ensuring that teams working in the area have a good understanding of the dataset they are working with: where it comes from, any inherent bias or blind spots, and which data points might need modifying or weighting due to their contextual/social salience. This will need to be driven by agreed best practices and AI development standards from organisations like AI Now and from academia, as well as by appropriate regulatory controls (although these face their own challenges, which I will discuss in a later article).

I also believe that these challenges mean an even greater need for diversity in teams, both in terms of the race, background, gender, etc. of team members, and in terms of cross-functional membership: not just engineers, but people working closely with domain experts for the specific field.




Photo credits:

Heading Photo by Alex Knight on Unsplash
Anonymous person Photo by Andrew Worley on Unsplash

Conference updates: JAXLondon and W-JAX 2017

I have had the pleasure of closing out this year by speaking at two conferences: the first, in October, was JAXLondon, and the second, in November, was W-JAX in Munich. The talk was titled "Agile Machine Learning: From theory to production".



As the title might suggest, the talk covered some considerations and challenges of doing Machine Learning (ML) work in a commercial environment, so a lot of the softer aspects of ML: when you should adopt it and how you can work on ML as a team.

The talk at JAXLondon was recorded, so hopefully at some point that will be available for me to share as well, but for the time being, here is a brief interview that my co-speaker and I gave ahead of our JAXLondon talk.



See more details and write up here

Agile Machine Learning: From Theory to Production

Later this year, Sumanas and I will be co-presenting a talk about researching Machine Learning in an agile development environment at the JAXLondon conference. This is a high level overview of some of the topics we will be presenting (we will also try to get some cool ML demos in there too, just to make things a bit more interesting).

So what’s the problem?

Artificial Intelligence (AI) and Machine Learning (ML) are all the rage right now. At Google’s recent I/O 2017 event they put ML front and centre, with plans to bake it into all their products, and lots of other large companies are positioning themselves as Machine Learning companies as a natural progression from Big Data.

According to a recent Narrative Science survey, 38% of enterprises surveyed were already using AI, with 62% expecting to be using it by 2018. So it would be understandable if many companies felt pressure to invest in an AI strategy before fully understanding what they are aiming to achieve, let alone how it might fit into a traditional engineering delivery team.

For the last 12 months we have been working on taking a new product to market, trying to go from a simple idea to a production ML system. Along the way we have had to integrate open-ended academic research tasks with our existing agile development process and project planning, as well as work out how to deliver the ML system to production in a repeatable, robust way, with all the considerations expected of a normal software project.

Here are a few things you might consider if you are planning your ML roadmap (and topics we will cover in more detail in the JAXLondon session in October):

Machine Learning != your product

Machine Learning is a powerful tool for enhancing a product, whether by reducing the cost of human curation or by giving your voice/natural language interface a more powerful understanding. However, Machine Learning shouldn’t be considered the selling point of the product. Think Product First: is there a market for the product regardless of whether it is powered by ML or by human oversight? Consider whether it makes sense to build a fully non-ML version of the product first, to start proving market fit and delivering value to customers.

Start small and build

The Lean Startup principles of an MVP and fast iterations still apply here. Following on from the point above, if you start with a non-ML product and then apply ML techniques to get even a small improvement (better recommendations, reduced human effort/cost, improved user experience; replacing a human process with ML for just 5% of cases can start to realise cost benefits), that adds value straight away. Starting small lets you prove the value ML can add whilst also getting the ML infrastructure tested and proven.
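
One common way to start small is to let the model handle only the cases it is most confident about, and route everything else through the existing human process. A minimal sketch, where the confidence scores, the 0.95 threshold and the names are hypothetical:

```scala
object HybridRouting {
  sealed trait Route
  case object Automated extends Route
  case object HumanReview extends Route

  // Only let the model act on cases it is confident about; everything else
  // keeps going through the existing human process.
  def route(confidence: Double, threshold: Double): Route =
    if (confidence >= threshold) Automated else HumanReview

  // Share of cases the model would handle at a given threshold
  def automationRate(confidences: Seq[Double], threshold: Double): Double =
    confidences.count(c => route(c, threshold) == Automated).toDouble / confidences.size
}
```

With 19 low-confidence cases and one at 0.97, only 5% of cases would be automated; the threshold can then be lowered gradually as confidence in the model grows.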

Tie into development sprint cycles

You may be hiring a new R&D team, or you may be using members of your existing engineering team; either way, it is helpful to have them working in similar development sprint cycles (if you work in sprints). This allows both sides of the team to understand what is happening and how work is progressing: product and engineering changes and issues might help inform the direction of R&D, and likewise there may be data features or feedback from the R&D team that could be easily engineered and would make research simpler. Whilst research is ongoing, and often time-consuming, having fortnightly (or whatever your sprint length is) checkpoints where ideas can be discussed and demoed can be good for the whole team’s understanding, as well as being a positive motivator.

Don’t forget Clean Code!

Whilst experimenting with and researching different ideas, it can be pretty easy to fall into hacking mode, just rattling out rough scripts to prove an initial concept. There is definitely a place for this, but as your team progresses it becomes more beneficial to invest in good coding principles. One-off scripts can reasonably be hacked out, but as the team works across several ideas, having reusable code, organised sensibly with proper separation of concerns, makes future research easier and reduces the cost of production-isation. Investing from the start in some machinery to make experiments easily testable (and to benchmark different solutions) will be very beneficial.
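
As an example of the kind of experiment machinery that pays off, here is a minimal sketch of a benchmarking harness: every candidate is treated as a plain function from input to label and scored against the same held-out evaluation set. The `Model` type alias, the toy even/odd task and the candidate names are hypothetical:

```scala
object ExperimentHarness {
  // A candidate model is just a function from input to predicted label,
  // so every experiment is evaluated the same way on the same data.
  type Model[A] = A => String

  def accuracy[A](model: Model[A], evalSet: Seq[(A, String)]): Double =
    evalSet.count { case (x, y) => model(x) == y }.toDouble / evalSet.size

  // Score every candidate on the same evaluation set, best first
  def benchmark[A](candidates: Map[String, Model[A]], evalSet: Seq[(A, String)]): Seq[(String, Double)] =
    candidates.toSeq
      .map { case (name, m) => name -> accuracy(m, evalSet) }
      .sortBy { case (_, score) => -score }
}
```

Keeping candidates behind a common interface like this means a new idea only has to plug in a function to be compared fairly against the baseline.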

Recommended Reading

While the interwebs are awash with Machine Learning articles, tutorials and click-baitey links guaranteed to reduce the error of your model in 3 quick steps, the following is a small list of resources that we think are worth browsing.



For the academically inclined, the following is a list of papers, both recent and not so recent:




During the session in JAXLondon later this year, we will go into more detail with these ideas, as well as others including technical and architectural considerations for building and deploying an ML stack.

Unsupervised Learning in Scala using Word2Vec

A pretty cool thing that has come out of recent Machine Learning advancements is the idea of "Word Embedding", specifically the advancements in the field made by Tomas Mikolov and his team at Google with the Word2Vec approach. Word Embedding is a language modelling approach that maps words to vectors of numbers. Imagine mapping every word in a given body of text to an N-dimensional vector (it may be easier to visualise this in 2 dimensions, so each word is a pair of coordinates that can be plotted on a graph); we could then plot the words and start to understand the relationships between them from their proximity. What's more, if we can map words to sets of numbers, we can start thinking about interesting arithmetic to perform on them.
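
The "proximity" idea usually comes down to cosine similarity between vectors. A minimal sketch with hand-picked 2-D vectors (not learned by any model; real embeddings typically have hundreds of dimensions) showing how the nearest words to a given word can be found:

```scala
object EmbeddingSketch {
  // Cosine similarity: 1.0 means the vectors point in the same direction
  def cosine(a: Vector[Double], b: Vector[Double]): Double = {
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    val norms = math.sqrt(a.map(x => x * x).sum) * math.sqrt(b.map(x => x * x).sum)
    dot / norms
  }

  // The n words whose vectors are most similar to the given word's vector
  def nearest(embeddings: Map[String, Vector[Double]], word: String, n: Int): Seq[String] = {
    val target = embeddings(word)
    embeddings.toSeq
      .filter { case (w, _) => w != word }
      .sortBy { case (_, v) => -cosine(target, v) }
      .take(n)
      .map(_._1)
  }
}
```

With toy vectors where "lamb" and "mutton" point roughly the same way and "rosemary" and "thyme" point another, the nearest word to "lamb" is "mutton" and the nearest to "rosemary" is "thyme".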

Sounds cool, right? Now of course, the tricky bit is how can you convert a word to a vector of numbers in such a way that it encapsulates the details behind this relationship? And how can we do it without painstaking manual work and trying to somehow indicate semantic relationships and meaning in the words?


Unsupervised Learning

Word2Vec relies on neural networks and trains on a large, un-labelled piece of text in a technique known as "unsupervised" learning.

Unlike the last neural network I discussed, which was a "supervised" exercise (i.e. for every input record we had the expected output/answer), Word2Vec uses a completely "unsupervised" approach. In other words, the neural network simply takes a massive block of text with no markup or labels (usually broken into sentences or lines) and uses that to train itself.

This kind of unsupervised learning can seem a little unbelievable at first. The idea that a network could train itself without even knowing the "answers" seemed strange to me the first time I heard it, especially as a fundamental requirement for a NN to converge on an optimum solution is a "cost function" (i.e. something we can use after each feed-forward step to tell us how right we are, and whether our NN is heading in the right direction).

But really, if we think back to the literal biological comparison with the brain, as people we learn through this unsupervised approach all the time: it's basically trial and error.


It's child's play

Imagine a toddler learning to use a smartphone or tablet: they probably aren't explicitly shown how to press an icon or swipe to unlock, but they try combinations of power buttons, volume controls and swipes and see what happens (and whether it does what they are ultimately trying to do). They get feedback from the device: not direct feedback about what the correct gesture is or how wrong they were, just the feedback that it doesn't do what they want. If you have ever lived with a toddler who has got to grips with touchscreens, you may have noticed that when they then encounter a TV or laptop, they instinctively try to touch or swipe the things on the screen that they want. (In NN terms this would be known as "overfitting": they have trained on too specific a set of data, so are poor at generalising. Luckily, the introduction of a non-touch screen such as a TV expands their training set, and they continue to improve their NN, getting better at generalising!)

So, this is basically how Word2Vec works. Which is pretty amazing if you think about it (well, I think it's neat).


Word2Vec approaches

So how does this apply to Word2Vec? Well, just as a smartphone gives implicit, indirect feedback to a toddler, the input data can provide its own feedback. There are broadly two techniques for training the network:

Continuous Bag of Words (CBOW)

So, our NN has a large body of text broken up into sentences/lines, and just like in our last NN example, we take the first row from the training set. But we don't push the whole sentence into the NN (after all, sentences vary in length, which would confuse our input neurons); instead we take a set number of words, referred to as the "window size" (let's say 5). In this approach, the goal is for the NN to correctly guess the middle word of that window: given the 4 surrounding words of a 5-word window, the NN attempts to guess the word at position 3.

[It was ___ of those] days, not much to do

So it's unsupervised learning, as we haven't had to go through the data and label things, or do any additional pre-processing: we can simply feed in any large body of text and it can try to guess the words from their context.
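
The CBOW training examples can be generated straight from raw text, which is what makes this unsupervised. A minimal sketch of that data-preparation step, assuming simple whitespace tokenisation and an odd window size as in the example above (it illustrates the idea; it isn't how DL4J implements it internally). Note that some libraries, gensim for example, define the window as the maximum distance either side of the target rather than the total span.

```scala
object CbowPairs {
  // Slide a window of `windowSize` words (odd) over the sentence: the middle
  // word is the target, and the words around it are the context given as input.
  def contextTargetPairs(sentence: String, windowSize: Int): Seq[(Seq[String], String)] = {
    require(windowSize % 2 == 1, "window size should be odd so there is one middle word")
    val words = sentence.toLowerCase.split("\\s+").toVector
    val half = windowSize / 2
    if (words.size < windowSize) Seq.empty
    else words.sliding(windowSize).map { window =>
      (window.take(half) ++ window.drop(half + 1), window(half))
    }.toVector
  }
}
```

For the example sentence with a window of 5, the first pair has the context "it was of those" and the target "one", and each following pair shifts the window along by one word.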

Skip-gram

The Skip-gram approach is similar, but the inverse - that is, given the word at position n, it attempts to guess the words at position n-2, n-1, n+1, n+2.

[__ ___ one __ _____] days, not much to do

The network is trying to work out which word(s) are missing, and just looks to the data itself to see if it can guess it correctly.
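
Skip-gram examples can be generated in much the same way, this time pairing each centre word with each word around it. A minimal sketch, again assuming whitespace tokenisation, with a radius of 2 (two words either side, matching positions n-2 to n+2 above):

```scala
object SkipGramPairs {
  // For every position, pair the centre word with each word up to `radius`
  // positions either side; the network learns to predict the context words
  // from the centre word.
  def centreContextPairs(sentence: String, radius: Int): Seq[(String, String)] = {
    val words = sentence.toLowerCase.split("\\s+").toVector
    for {
      i <- words.indices
      j <- (i - radius) to (i + radius)
      if j != i && j >= 0 && j < words.size
    } yield (words(i), words(j))
  }
}
```

For "it was one of those", the centre word "one" produces the pairs (one, it), (one, was), (one, of) and (one, those), while words near the edges of the sentence simply get fewer context words.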


Word2Vec with DeepLearning4J

One popular deep learning and Word2Vec implementation on the JVM is DeepLearning4J. It is simple enough to use while you get used to what is going on, and is pretty well documented (with some good high-level overviews of core topics). You can get up and running with the library and some example datasets pretty quickly by following their guide. Their NN setup is equally simple and worth playing with; their MNIST hello-world tutorial gets you up and running with that dataset pretty quickly.

Food2Vec

A little while ago, I wrote a web crawler for the BBC food recipe archive, so I happened to have several thousand recipes sitting around and thought it might be fun to feed those recipes into Word2Vec to see if it could give any interesting results or if it was any good at recommending food pairings based on the semantic features the Word2Vec NN extracts from the data.

The first thing I tried was using each recipe's complete ingredient list as a sentence, hoping this would be better for extracting the relationships between ingredients. My hope was that if I queried the trained model for "X is to Beef, as Rosemary is to Lamb", I would start to get some interesting results, or at least be able to enter an ingredient and get similar ingredients to help identify possible substitutions.

As you can see, it has managed to extract some meaning from the data: for both pork and lamb, the nearest words do seem related to the target word, though not in a way that would be really useful. Still, this in itself is pretty exciting: it has taken an unlabelled body of text and learnt some pretty accurate relationships between words.

On reflection, a list of ingredients isn't a great input, as it isn't a natural structure and the words have no natural ordering; a lot of meaning is captured in phrases rather than in lists of words.

So next, I used the recipe instructions: each step in a recipe became an input sentence, and minimal cleanup was needed. However, even with some basic tweaking (it's quite possible that playing more with the Word2Vec configuration would have improved the results), the results weren't much better. For the same lamb and pork search, this was the output:

Again, it's still impressive to see that some meaning has been found in these words, but is it better than the raw ingredient list? I think not: the pork result seems wrong, as it has very much aligned pork with poultry (although maybe that is some meaningful insight that conventional wisdom just hasn't taught us yet!?)

Arithmetic

Whilst this is pretty cool, there is further fun to be had in the form of simple arithmetic. A simple, often-quoted example is countries and their capital cities: in well-trained Word2Vec models, each country and its capital city sit roughly the same distance (and direction) apart:

(graph taken from DeepLearning4J Word2Vec intro)

So could we extract similar relationships between foodstuffs? The short answer, with the models trained so far, was kind of...

Word2Vec supports the idea of positive and negative matches when looking for nearest words, which allows you to find these kinds of relationships. So what we are looking for is something like "X is to Lamb, as thigh is to chicken" (i.e. this should hopefully find a cut of lamb), and we could hopefully use this to extract further information about ingredient relationships that could be useful in thinking about food.
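
Under the hood, a positive/negative query like this is vector arithmetic: add the positive vectors, subtract the negative ones, and look for the word nearest to the result (excluding the query words themselves). A minimal plain-Scala sketch, with hand-picked 3-D toy vectors rather than anything learned from the recipes; the `wordsNearest` name mirrors the positive/negative query style described here, but it is a standalone illustration, not DL4J's implementation:

```scala
object AnalogySketch {
  type Vec = Vector[Double]

  private def add(a: Vec, b: Vec): Vec = a.zip(b).map { case (x, y) => x + y }
  private def sub(a: Vec, b: Vec): Vec = a.zip(b).map { case (x, y) => x - y }

  private def cosine(a: Vec, b: Vec): Double = {
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    dot / (math.sqrt(a.map(x => x * x).sum) * math.sqrt(b.map(x => x * x).sum))
  }

  // Sum the positive vectors, subtract the negative ones, then return the
  // closest words by cosine similarity that are not part of the query.
  def wordsNearest(embeddings: Map[String, Vec], positive: Seq[String], negative: Seq[String], top: Int): Seq[String] = {
    val zero: Vec = Vector.fill(embeddings.head._2.size)(0.0)
    val query = sub(positive.map(embeddings).foldLeft(zero)(add), negative.map(embeddings).foldLeft(zero)(add))
    val exclude = (positive ++ negative).toSet
    embeddings.toSeq
      .filterNot { case (w, _) => exclude(w) }
      .sortBy { case (_, v) => -cosine(query, v) }
      .take(top)
      .map(_._1)
  }
}
```

With these toy vectors, thigh - chicken + lamb lands exactly on the vector for "shank", so querying with positive lamb and thigh and negative chicken returns "shank".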

So, I ran that arithmetic against my two models.
The instructions-based model returned the following output:

Which is a pretty good effort. If I had to name a lamb equivalent of chicken thigh, a lamb shank is probably what I would have gone for (both come from the leg, both are slow-twitch muscle, and both are among the more gamey, flavourful pieces of the animal; I will stop there, as we are getting into food-nerd territory).

I also ran the same query on the ingredients-based model (which, remember, performed better on the basic nearest-words test):

Interestingly, this doesn't seem as good. It returns the shin, which isn't too bad insofar as it's from the leg of the animal, but it's not quite as good a match as the previous result.


Let us play

Once you have the input data, Word2Vec is super easy to get up and running. As always, the code is on GitHub if you want to see the build setup (I did have to fudge some dependencies and exclude a few things to get it running on Ubuntu; you may get errors about javacpp or jnind4j not being available, but the build file has the required workarounds in it). The interesting bit is as follows:
If we run through what we are setting up here:

  1. Stop words - these are words we know we want to ignore. I originally ruled these out as I didn't want the measurements of ingredients to carry too much meaning.
  2. Line iterator and tokenizer - these are core DL4J classes that take care of processing the text line by line and word by word, which makes things much easier for us as we don't have to worry about that stuff.
  3. Min word frequency - this is the threshold for words to be interesting to us: if a word appears fewer than this number of times in the text then we don't include a mapping for it (as we aren't confident we have a strong enough signal for it).
  4. Iterations - how many training cycles we are going to loop for.
  5. Layer size - this is the size of the vector that we will produce for each word. In this case we are saying we want to map each word to a 300-dimensional vector; you can think of each dimension of the vector as a "feature" of the word that is being learnt. This is a part of the network that will really need to be tuned to each specific problem.
  6. Seed - this is used to "seed" the random number generator used in the network setup; setting it helps us get more repeatable results.
  7. Window size - this is how many neighbouring words around each word are used as its context when training our NN; it relates to the CBOW/Skip-gram approaches described above.
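To get a feel for what the stop words, min word frequency and window size settings do to the training data, here is a rough sketch in plain Scala. It is not DL4J's implementation (which does considerably more), and the `Word2VecPrep` object and its method names are hypothetical: it drops stop words and rare words, then generates the (word, context word) pairs that a skip-gram model would learn from.

```scala
object Word2VecPrep {
  /** Drop stop words, and words seen fewer than minWordFrequency times across all sentences. */
  def filterTokens(sentences: Seq[Seq[String]], stopWords: Set[String], minWordFrequency: Int): Seq[Seq[String]] = {
    // Word counts are taken over the raw tokens, before any filtering
    val counts = sentences.flatten.groupBy(identity).map { case (w, ws) => w -> ws.size }
    sentences.map(_.filter(w => !stopWords.contains(w) && counts(w) >= minWordFrequency))
  }

  /** (target, context) pairs for every word within windowSize positions either side, as in skip-gram. */
  def skipGramPairs(sentence: Seq[String], windowSize: Int): Seq[(String, String)] =
    for {
      (target, i) <- sentence.zipWithIndex
      j <- math.max(0, i - windowSize) to math.min(sentence.length - 1, i + windowSize)
      if j != i
    } yield (target, sentence(j))
}
```

For example, with the (made-up) recipe steps "add 2 cups of flour" and "add the flour", stop words `of`, `the` and `cups`, and a minimum frequency of 2, both sentences reduce to just `add flour`; a larger window simply produces more context pairs per word.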

And that's all you need to get your first Word2Vec model up and running! So go have fun: find some interesting text datasets to feed in, see what you can work out about the relationships between the words, and feel free to comment here with anything interesting you find.

Machine (re)learning: Neural Networks from scratch

Having recently changed roles, I am now in the enviable position of starting to do some work with machine learning (ML) tools. My undergraduate degree was actually in Artificial Intelligence, but that was over a decade ago, which is a long time in computer science in general, let alone in Machine Learning and AI, which have progressed massively in the last few years.

So, needless to say, coming back to it there has been a lot to learn. In the future I will write a bit more about some of the tools and libraries that exist these days (both expanding on the more traditional Stanford AI libraries I mentioned previously in my tweet sentiment analysis, and newer frameworks that cover "deep learning").

Anyway, inspired by this post, I thought it would be a fun Sunday night refresher to write my own neural network. The first, and last, time that I wrote a neural network was for my final year dissertation (and that code is long gone), so I was writing it from first principles. The rest of this post is a very straightforward introduction to the ideas and the code for a basic neural network with a single hidden layer and a simple sigmoid activation function. I am training and testing on a very simple labelled dataset, and with the current configuration it scores 90% on the unseen classification tests. I am by no means an expert in this field, and there are plenty of resources and papers written by people who are inventing this stuff as we speak; this is more the musings of someone re-learning it.

As always, all the code is on GitHub (and, as per my change in roles, this time it is all written in Scala).

A brief overview

The basic idea is that a Neural Network (NN) attempts to mimic the parallel architecture of the human brain. The human brain is a massive network of billions of simple, interconnected neural cells; given a stimulus, each of them either "fires" or doesn't, and it's these firing neural cells and synapses that enable us to learn (excuse my crude explanation, I'm clearly not a biologist!).

So that is what we are trying to build: A connected network of neurons that given some stimuli either "fire" or don't.

A Neuron

Ok, so this seems like a sensible place to start, right? We all know what a network is, so what are these nodes in our network that we are connecting? These are the decision points, and in themselves they are incredibly simple. Each one is basically a function that takes n input values, multiplies each by its own weight (one per input), sums the results, adds a bias and then runs the total through an activation function (think of this as our fire/don't-fire function):
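Written out as an equation (the symbols are my own labels, not the original diagram's: $x_i$ are the $n$ inputs, $w_i$ the matching weights, $b$ the bias and $\varphi$ the activation function, which for the sigmoid used in this post is $\sigma$):

$$
y = \varphi\left(\sum_{i=1}^{n} w_i x_i + b\right), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$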



In terms of our code, this is pretty simple (don't worry about the sigmoid derivative values we are setting; we are just doing this to save time later):

As you can see, the neuron holds the state for the weights of its different inputs (a NN is normally fixed in its number of neurons, so once it is initialised at the start we know that each neuron will always get the same number of inputs).

You may have also noticed the first line of the class, where we require an ActivationFunction to be applied - as I mentioned, this is our final processing of the output. In this example I am just using the Sigmoid function:

As you can see, it's a pretty simple function. Much like the brain, the neurons are very simple processors, and it's only the combination of them that makes the brain (or deep learning, for that matter) so powerful.
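For readers who want something to run, here is a minimal sketch of a neuron and a sigmoid activation along the lines described above. It is not the code from the GitHub repo: `ActivationFunction` and `Sigmoid` follow the names used in the text, but their members, and the `Neuron` class with its `lastOutput`/`lastDerivative` fields, are my own assumptions. Caching the derivative during the forward pass is one way to "save time later", since backpropagation needs it; it uses the identity sigma'(z) = sigma(z) * (1 - sigma(z)).

```scala
trait ActivationFunction {
  def apply(z: Double): Double
  def derivative(z: Double): Double
}

object Sigmoid extends ActivationFunction {
  def apply(z: Double): Double = 1.0 / (1.0 + math.exp(-z))
  // sigma'(z) = sigma(z) * (1 - sigma(z))
  def derivative(z: Double): Double = { val s = apply(z); s * (1 - s) }
}

// Hypothetical neuron: one weight per input, a bias, and an activation function
class Neuron(val weights: Array[Double], var bias: Double, activation: ActivationFunction) {
  var lastOutput: Double = 0.0     // cached for use during training
  var lastDerivative: Double = 0.0 // cached activation derivative at the last weighted input

  def output(inputs: Array[Double]): Double = {
    require(inputs.length == weights.length, "expected one weight per input")
    val z = inputs.zip(weights).map { case (x, w) => x * w }.sum + bias
    lastOutput = activation(z)
    lastDerivative = activation.derivative(z)
    lastOutput
  }
}
```

For instance, `new Neuron(Array(1.0, -1.0), 0.0, Sigmoid).output(Array(2.0, 2.0))` gives a weighted input of 0, so the neuron outputs 0.5 and caches a derivative of 0.25.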

The network

Ok, so a simple "shallow" NN has an input layer, a single hidden layer and an output layer (deep NNs have multiple hidden layers). The network is fully connected between layers: each node is connected to every node in the next layer up, and that's it. No neurons are connected to any other layers and no connections are missed.



The exact setup of the network is largely problem dependent, but there are some observations:

  1. The input layer has to correspond to your inputs: if you have a dataset with two input values, say house price data with the number of rooms and the square footage as inputs, then you would need two input neurons.

  2. Similarly, if you were using the NN for a classification problem with a fixed number of classes, the output neurons would correspond to those. For example, if you were using the NN to recognise handwritten digits (0-9), you would likely have 10 output neurons, one for each possible digit.

The number of hidden neurons, and the number of hidden layers, is much more problem-dependent and needs to be tuned for each problem.

If you have ever worked with graph-type structures in code before, then setting up a simple network of these neurons is relatively straightforward, given the uniform structure of the network.

The weights

Oh right, yes. So all these neurons are connected, but it's a weighted graph, in CS terms. That is, each connection between nodes is assigned a weight. As a simple example, taking the house price dataset again, after looking at the data we might determine that the number of rooms is a more significant factor in the end result, so that connection should have a greater weighting than the other input. That is not exactly what happens, but in simple terms it's a good way of thinking about it.

Given that the end goal is for the NN to learn from the data, it really doesn't matter too much what we initialise the weights for all the connections to, so at start-up we can just assign random values. Just think of our NN as a newborn baby - it has no idea how important different inputs are, or how important the different neural cells that fire in response to the stimuli are - everything is just firing off all over the place as they slowly start to learn!
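Here is a minimal sketch of setting up that fully connected structure with random starting weights. The `NetworkSetup` object, the `weights(layer)(to)(from)` layout and the uniform [-1, 1) range are all my own assumptions (the post doesn't say which range it uses), and biases could be initialised the same way. One practical reason the starting values should be random rather than identical: if every weight started at the same value, every hidden neuron would compute the same output and receive the same update, so they could never learn different things.

```scala
import scala.util.Random

object NetworkSetup {
  /** One weight matrix per pair of adjacent layers: weights(l)(j)(i) links neuron i in layer l to neuron j in layer l + 1. */
  def randomWeights(layerSizes: Seq[Int], rng: Random): Vector[Vector[Vector[Double]]] = {
    require(layerSizes.length >= 2, "need at least an input and an output layer")
    layerSizes.sliding(2).map { case Seq(in, out) =>
      // Fully connected: every neuron in the next layer gets a weight from every neuron in this one
      Vector.fill(out, in)(rng.nextDouble() * 2 - 1) // uniform in [-1, 1)
    }.toVector
  }
}
```

For a 2-3-2 network, `NetworkSetup.randomWeights(Seq(2, 3, 2), new Random(42))` yields a 3x2 matrix for the input-to-hidden connections and a 2x3 matrix for the hidden-to-output ones; fixing the seed keeps runs repeatable.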

So what next?

Ok, so we have our super-simple neurons, which just mimic a single brain cell, and we have them connected in a structured but randomly weighted fashion. So how does the network learn? Well, this is the interesting bit...

When it comes to training the network, we need a pretty large dataset so that it can learn enough to start generalising. At the same time, we don't want to train it on all the data, as we want to hold some back to test it at the end and see just how smart it really is.
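Holding data back can be as simple as shuffling the labelled records and setting a slice aside before training starts. A minimal sketch, assuming a hypothetical `Split.trainTest` helper and a fixed seed so the split is repeatable between runs:

```scala
import scala.util.Random

object Split {
  /** Shuffle the records, then hold back testFraction of them for testing; returns (training, test). */
  def trainTest[A](records: Seq[A], testFraction: Double, seed: Long): (Seq[A], Seq[A]) = {
    val shuffled: Seq[A] = new Random(seed).shuffle(records)
    val testSize = math.round(records.size * testFraction).toInt
    (shuffled.drop(testSize), shuffled.take(testSize))
  }
}
```

Shuffling first matters if the source data is sorted in any way (for example by class), otherwise the held-back slice might not be representative of the whole dataset.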

In AI terms, we are going to be performing "supervised" learning - this just means that we will train it with a dataset where we know the correct answer, so we can make adjustments based on how well (or badly) the network is doing - this is different to "unsupervised" learning, where we have lots of data, but we don't have the right answer for each data point.

Training: Feed forward


The first step is the "feed-forward" step. This is where we grab the first record from our training dataset and feed its inputs into our randomly initialised NN; the input is fed through all of the neurons in the NN until we get to the output layer, and we have the network's attempt at the answer. As the weighting is all random, this is going to be way off (imagine a toddler's guess the first time they ever see a smartphone).

As you can see, the code is really simple: just iterate through the layers, calculating the output for each neuron. Now, as this is supervised learning, we also have the expected output for each record, which means we can work out how far off the network is and attempt to adjust the weights in the network so that next time it performs a bit better.
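As a self-contained illustration (not the repo's code), the feed-forward pass can be written as a fold over the layers: each layer's outputs become the next layer's inputs. This sketch assumes a hypothetical `Layer` type holding a weight matrix, where `weights(j)(i)` connects input `i` to neuron `j`, plus one bias per neuron, and uses a sigmoid activation throughout.

```scala
object FeedForward {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  /** One fully connected layer: weights(j)(i) connects input i to neuron j; biases(j) belongs to neuron j. */
  final case class Layer(weights: Vector[Vector[Double]], biases: Vector[Double])

  /** Push the inputs through every layer in turn; the last layer's outputs are the network's answer. */
  def feedForward(layers: Seq[Layer], inputs: Vector[Double]): Vector[Double] =
    layers.foldLeft(inputs) { (activations, layer) =>
      layer.weights.zip(layer.biases).map { case (ws, b) =>
        sigmoid(ws.zip(activations).map { case (w, a) => w * a }.sum + b)
      }
    }
}
```

For example, a 2-2-1 network whose hidden weights are `(1, -1)` and `(0, 0)` and whose output weights are `(1, -1)`, all with zero biases, maps the input `(1, 1)` to a single output of 0.5, since every weighted sum along the way comes out as 0.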

Training: Back propagation

This is where the magic, and the maths, comes in. The TL;DR is basically: we know how wrong the network's guess was, so we update each of the connection weights based on that error, using differentiation to work out which direction to adjust each weight in to reduce the error.

This took a fair bit of staring at equations and racking my brain trying to remember my A-level maths, and even now I wouldn't like to have to go back and explain it to my old maths teacher, but having coded it up, I'm happy I have enough of a grasp of what is going on.

In a very broad, rough overview, here is what we are going to do now:

  1. Calculate the squared error of the outputs. That is a fairly simple equation of
    1/2 * (target - output)^2

  2. From there, we work out the derivative of this function. A lot of the errors we will work with from now on use differentiation and the derivatives of the result. This takes a little bit of calculus, but the basic reason is relatively straightforward if you think about what differentiation is for.

  3. We work back through the network, calculating the error for every neuron by combining the error from the layer after it, the derivatives and the connection weights (the weighting determines how much a particular connection contributed to an error, so it also needs to be considered when correcting that weighting).
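For a sigmoid network with the squared error from step 1, those three steps work out as follows. The notation is mine rather than the original's: $z_j$ is neuron $j$'s weighted input, $o_j$ its output, $t_k$ the target for output neuron $k$, $w_{ji}$ the weight on the connection from neuron $i$ into neuron $j$ (for the first layer, $o_i$ is just the raw input), and $\eta$ is a learning rate, which the steps above don't mention but gradient descent needs.

$$
E = \frac{1}{2}\sum_k (t_k - o_k)^2, \qquad o_j = \sigma(z_j), \qquad z_j = \sum_i w_{ji}\, o_i + b_j
$$

$$
\delta_k = \frac{\partial E}{\partial z_k} = (o_k - t_k)\, o_k (1 - o_k) \quad \text{(output neurons)}
$$

$$
\delta_j = \Big(\sum_k w_{kj}\, \delta_k\Big)\, o_j (1 - o_j) \quad \text{(hidden neurons)}
$$

$$
w_{ji} \leftarrow w_{ji} - \eta\, \delta_j\, o_i, \qquad b_j \leftarrow b_j - \eta\, \delta_j
$$

The $o(1 - o)$ factors come from the sigmoid's derivative, $\sigma'(z) = \sigma(z)(1 - \sigma(z))$, and $(o_k - t_k)$ is the derivative of $\tfrac{1}{2}(t_k - o_k)^2$ with respect to $o_k$.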

Once we have adjusted the weight of each connection based on its share of the error, we start again with the next training record. Given the variation in the dataset, the next record might be quite different to the previous one, so the adjustment might be quite different too. After lots (maybe millions) of iterations over the data, incrementally tweaking and adjusting the weights for the different cases, the network hopefully starts to converge on something like a reasonable performance.
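Putting feed-forward and back propagation together for a single training record, here is a minimal sketch of a network with one hidden layer. It assumes sigmoid activations everywhere, the squared error above, a fixed learning rate, weights initialised uniformly in [-1, 1) and one update per record; the `TinyNet` class and its member names are hypothetical rather than taken from the author's GitHub code.

```scala
import scala.util.Random

// Hypothetical single-hidden-layer network, trained one record at a time by backpropagation
class TinyNet(nIn: Int, nHidden: Int, nOut: Int, learningRate: Double, seed: Long) {
  private val rng = new Random(seed)
  private def randomMatrix(rows: Int, cols: Int): Array[Array[Double]] =
    Array.fill(rows, cols)(rng.nextDouble() * 2 - 1) // uniform in [-1, 1)

  val wHidden: Array[Array[Double]] = randomMatrix(nHidden, nIn) // wHidden(j)(i): input i -> hidden j
  val bHidden: Array[Double] = Array.fill(nHidden)(0.0)
  val wOut: Array[Array[Double]] = randomMatrix(nOut, nHidden)   // wOut(k)(j): hidden j -> output k
  val bOut: Array[Double] = Array.fill(nOut)(0.0)

  private def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Outputs of one fully connected sigmoid layer
  private def layer(w: Array[Array[Double]], b: Array[Double], in: Array[Double]): Array[Double] =
    w.indices.map(j => sigmoid(w(j).indices.map(i => w(j)(i) * in(i)).sum + b(j))).toArray

  def predict(x: Array[Double]): Array[Double] = layer(wOut, bOut, layer(wHidden, bHidden, x))

  // Squared error, 1/2 * sum of (target - output)^2, for one record
  def error(x: Array[Double], target: Array[Double]): Double =
    predict(x).zip(target).map { case (o, t) => 0.5 * (t - o) * (t - o) }.sum

  // One feed-forward pass followed by one backpropagation update
  def train(x: Array[Double], target: Array[Double]): Unit = {
    val hidden = layer(wHidden, bHidden, x)
    val out = layer(wOut, bOut, hidden)
    // Output deltas: (o - t) * o * (1 - o), using sigmoid'(z) = o * (1 - o)
    val deltaOut = out.indices.map(k => (out(k) - target(k)) * out(k) * (1 - out(k))).toArray
    // Hidden deltas: output deltas pushed back through the (not yet updated) output weights
    val deltaHidden = hidden.indices.map { j =>
      out.indices.map(k => wOut(k)(j) * deltaOut(k)).sum * hidden(j) * (1 - hidden(j))
    }.toArray
    // Gradient descent: each weight moves by -learningRate * delta * (input along that connection)
    for (k <- wOut.indices) {
      for (j <- hidden.indices) wOut(k)(j) -= learningRate * deltaOut(k) * hidden(j)
      bOut(k) -= learningRate * deltaOut(k)
    }
    for (j <- wHidden.indices) {
      for (i <- x.indices) wHidden(j)(i) -= learningRate * deltaHidden(j) * x(i)
      bHidden(j) -= learningRate * deltaHidden(j)
    }
  }
}
```

Calling `train` over every training record, for many passes over the (shuffled) data, is the "lots of iterations" loop described above. The hidden deltas are computed before the output weights change, so both layers are corrected against the same forward pass.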

Maths fun: Why the derivative of the errors?

So, why is calculating the derivative relevant to adjusting the weights?

If you imagine the network's error as a function of all the different weights, it's pretty complicated, but for the sake of easier visualisation we can reduce it to just two weights, so that the error can be plotted against them as a surface in 3-d space (the two weights on the horizontal axes and the error on the vertical axis). Now imagine our function plots a graph something like this:


(taken from the wikipedia page on partial derivatives)

The partial derivatives of the function tell us the direction of the slope in the graph at a given point, so with our current weights as the co-ordinates, we can use them to work out the direction in which we need to adjust these weights to move downhill, towards a lower error.
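To make the "downhill" idea concrete, here is a minimal gradient descent sketch on a made-up two-weight error function, f(w1, w2) = (w1 - 3)^2 + (w2 + 1)^2 (hypothetical, chosen because its lowest point is obviously at w1 = 3, w2 = -1). At each step we compute the partial derivatives at the current point and move the weights a small step against them; the `GradientDescent` object and its parameters are my own.

```scala
object GradientDescent {
  /** Minimise f(w1, w2) = (w1 - 3)^2 + (w2 + 1)^2 by repeatedly stepping against its partial derivatives. */
  def minimise(start: (Double, Double), learningRate: Double, steps: Int): (Double, Double) =
    (1 to steps).foldLeft(start) { case ((w1, w2), _) =>
      val dW1 = 2 * (w1 - 3) // partial derivative of f with respect to w1
      val dW2 = 2 * (w2 + 1) // partial derivative of f with respect to w2
      (w1 - learningRate * dW1, w2 - learningRate * dW2)
    }
}
```

Starting from (0, 0) with a learning rate of 0.1, each step shrinks the distance to the minimum by a factor of 0.8, so after 100 steps the weights are within about a billionth of (3, -1). Too large a learning rate would overshoot instead: above 1.0 this particular function diverges.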

Conclusion

The whole NN in Scala is currently on GitHub, and there isn't much more code than I have included here. I have found it pretty helpful to code it up from scratch and actually have to think about it, and it has been fun seeing it train up to 90% accuracy on the unseen dataset I was using (just a simple two-in, two-out dataset, so not that impressive).

Next up I will see how the network performs against the MNIST dataset (the classification of handwritten digits, which is something like the Hello World benchmark of machine learning).

Oh, and if you are wondering why the images look like a dodgy photocopy, they are photos of the original diagrams that I included in my final year dissertation at university, just for old times' sake!