Re-thinking the Visitor pattern with Scala, Shapeless & polymorphic functions

In this article I will look at a relatively boilerplate-free way to traverse tree structures in Scala, using polymorphic functions along with Shapeless' everything function.

Over the course of my career, a problem I have had to face fairly repeatedly is dealing with nested, tree-like structures of arbitrary depth. From XML to directory structures to data models, nested trees or documents are a common and pretty useful way to model data.

Early in my career (the classic Java/J2EE/Spring days) I tackled them using the classic Visitor pattern from the Gang of Four, and I have probably had more than my fair share of implementing that pattern. Later, whilst working in Groovy, I re-imagined the pattern a little to make it more idiomatic (dealing mostly with Maps and Lists). Now I am working in Scala, and once again the problem has arisen.

There are lots of things that Scala handles well. I do generally like its type system, and everyone always raves about the pattern matching (which is undeniably useful), but it has always irked me a bit that, when dealing with child classes, I have to match on every implementation to do something. I always feel like it's something I should be able to do with type classes, and I inevitably end up a little sad every time I remember I can't. Let me explain with a quick example: let's imagine we are modeling a structure like XML. (I will assume we all know XML, but essentially the format lets you define nested tree structures of elements. An element can be a complex type, e.g. something like a directory or object that holds further child elements, or a simple type, e.g. a string element that holds a string.)

A basic setup to model this tree structure has a sealed trait for the generic element, a class for the complex element (an element that holds a further list of child elements), and a couple of basic classes for the simple elements (String/Boolean/Double).
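A minimal sketch of that setup follows. Only the ComplexElement value field being a List[Element] is confirmed by the text; the name fields, the field names of the simple elements and the sample document are assumptions for illustration.

```scala
// The generic element: every node in the tree is an Element.
sealed trait Element

// A complex element holds a further list of child elements.
final case class ComplexElement(name: String, value: List[Element]) extends Element

// Simple elements each wrap a single value (field names are assumed).
final case class StringElement(name: String, value: String) extends Element
final case class BooleanElement(name: String, value: Boolean) extends Element
final case class DoubleElement(name: String, value: Double) extends Element

// A small two-level example document.
val document: ComplexElement = ComplexElement("root", List(
  StringElement("title", "Visitors"),
  BooleanElement("published", true),
  ComplexElement("stats", List(DoubleElement("rating", 4.5)))
))
```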

Now, when we have a ComplexElement and want to process its children, a List[Element], ideally type classes would come to our rescue.

The idea is a simple ValidatorTypeClass with an implementation for each of the different types we care about. From there, traversing a nested structure looks relatively simple: the type class instance for ComplexElement just iterates through its children and recursively passes each one to the child element's type class to handle the logic. (Note: I will use validation as an example throughout this article, but only for the sake of a simple illustration. There are many better ways to perform simple attribute validation in Scala, but it provides a useful context for the problem.)
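Here is a sketch of that type class idea, using an assumed element model. The instance helper and the validation rules (non-empty strings, non-NaN doubles) are hypothetical. The recursive ComplexElement instance is left commented out, because it is the part the compiler rejects: each child is statically typed as Element, and no ValidatorTypeClass[Element] exists.

```scala
sealed trait Element
final case class ComplexElement(name: String, value: List[Element]) extends Element
final case class StringElement(name: String, value: String) extends Element
final case class BooleanElement(name: String, value: Boolean) extends Element
final case class DoubleElement(name: String, value: Double) extends Element

// A type class with one instance per element type we care about.
trait ValidatorTypeClass[T] {
  def validate(t: T): Boolean
}

object ValidatorTypeClass {
  // Summoner: ValidatorTypeClass[StringElement] returns the implicit instance.
  def apply[T](implicit v: ValidatorTypeClass[T]): ValidatorTypeClass[T] = v

  // Hypothetical helper that builds an instance from a plain function.
  def instance[T](f: T => Boolean): ValidatorTypeClass[T] =
    new ValidatorTypeClass[T] { def validate(t: T): Boolean = f(t) }

  // Hypothetical rules for the simple elements.
  implicit val stringValidator: ValidatorTypeClass[StringElement] =
    instance[StringElement](_.value.nonEmpty)
  implicit val booleanValidator: ValidatorTypeClass[BooleanElement] =
    instance[BooleanElement](_ => true)
  implicit val doubleValidator: ValidatorTypeClass[DoubleElement] =
    instance[DoubleElement](d => !d.value.isNaN)

  // The recursive instance we would like to write. Uncommenting it fails to
  // compile: each child is statically an Element, and there is no
  // ValidatorTypeClass[Element] for the implicit search to find.
  //
  // implicit val complexValidator: ValidatorTypeClass[ComplexElement] =
  //   instance[ComplexElement](c => c.value.forall(child => ValidatorTypeClass[Element].validate(child)))
}
```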

However, if you try to compile this, it fails with a missing-implicit error.

The reason is that the compiler is looking for an implicit type class instance for the parent type Element (the ComplexElement value attribute is defined as List[Element]), which we haven't defined. Sure, we could define a ValidatorTypeClass[Element] and simply pattern match the input across all the implemented types, but at that point there is no real point in having type classes: you just end up with a big old pattern-matching block. That works, but it feels verbose, especially as you inevitably need to handle the tree structure in several different places and ways, so the same blocks end up repeated throughout the code.
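For comparison, a sketch of that pattern-matching workaround, with the same kind of assumed model and hypothetical rules: a ValidatorTypeClass[Element] that dispatches on every subtype, which in turn lets the recursive ComplexElement instance compile.

```scala
sealed trait Element
final case class ComplexElement(name: String, value: List[Element]) extends Element
final case class StringElement(name: String, value: String) extends Element
final case class BooleanElement(name: String, value: Boolean) extends Element
final case class DoubleElement(name: String, value: Double) extends Element

trait ValidatorTypeClass[T] { def validate(t: T): Boolean }

object ValidatorTypeClass {
  def apply[T](implicit v: ValidatorTypeClass[T]): ValidatorTypeClass[T] = v

  // Hypothetical helper and rules, as in a plain type class setup.
  def instance[T](f: T => Boolean): ValidatorTypeClass[T] =
    new ValidatorTypeClass[T] { def validate(t: T): Boolean = f(t) }

  implicit val stringValidator: ValidatorTypeClass[StringElement] =
    instance[StringElement](_.value.nonEmpty)
  implicit val booleanValidator: ValidatorTypeClass[BooleanElement] =
    instance[BooleanElement](_ => true)
  implicit val doubleValidator: ValidatorTypeClass[DoubleElement] =
    instance[DoubleElement](d => !d.value.isNaN)

  // The recursive instance now compiles, because ValidatorTypeClass[Element]
  // (defined below) is available to the implicit search.
  implicit val complexValidator: ValidatorTypeClass[ComplexElement] =
    instance[ComplexElement](c => c.value.forall(child => ValidatorTypeClass[Element].validate(child)))

  // The workaround: an instance for the parent type that pattern matches on
  // every subtype and delegates to the specific instance.
  implicit val elementValidator: ValidatorTypeClass[Element] = instance[Element] {
    case c: ComplexElement => complexValidator.validate(c)
    case s: StringElement  => stringValidator.validate(s)
    case b: BooleanElement => booleanValidator.validate(b)
    case d: DoubleElement  => doubleValidator.validate(d)
  }
}

val good = ComplexElement("root", List(
  StringElement("title", "Visitors"),
  ComplexElement("stats", List(DoubleElement("rating", 4.5)))
))
val bad = ComplexElement("root", List(
  ComplexElement("stats", List(DoubleElement("rating", Double.NaN)))
))
```

The match has to list every subtype by hand, and a similar block tends to reappear every time the tree is processed in a different way.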

So I wanted to find a better way and, having written about Shapeless a couple of times before, once again...

Enter Shapeless

The good news is that Shapeless has some tools that can help improve this. The bad news is that there isn't really any documentation on some of the features (beyond reading the source code and unit tests), and some of them don't seem to be mentioned anywhere at all! I had previously used a function Shapeless provides called everywhere. Even this function isn't explicitly called out in the docs: I stumbled upon it in an article about what was new in Shapeless 2.0, where it appeared in a piece of example code without any mention or explanation. everywhere allows in-place editing of tree-like structures (or any structures, really) and was based on the ideas laid out in the Scrap Your Boilerplate (SYB) paper, which large parts of the Shapeless library were based on.

As well as everywhere, Shapeless also provides a function called everything, which also comes from the SYB paper. Instead of editing, it lets you simply traverse, or visit, generic data structures. It's pretty simple conceptually, but finding any mention of it in docs or footnotes was hard (I found it by reading the source code), so let's go through it.

everything takes three arguments, supplied in separate parameter lists: everything(validates)(combine)(complex).

The first, validates, is the polymorphic function we want to apply at every step of the data structure; combine is a polymorphic function that combines the results; and complex, the third argument, is our input: in this case the root of our nested data model.
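To make the three-argument shape concrete before building the validating functions, here is a minimal sketch that counts the StringElement nodes in a tree. It assumes Scala 2.13 and Shapeless 2.3.x (the scala-cli using directives are just one way to fetch it); the element model and the names CountDefault, countStrings and addCounts are hypothetical.

```scala
//> using scala 2.13
//> using dep com.chuusai::shapeless:2.3.10
import shapeless._

sealed trait Element
final case class ComplexElement(name: String, value: List[Element]) extends Element
final case class StringElement(name: String, value: String) extends Element
final case class BooleanElement(name: String, value: Boolean) extends Element
final case class DoubleElement(name: String, value: Double) extends Element

// Fallback case, kept in a parent trait so it has lower implicit priority:
// every node that is not a StringElement contributes 0.
trait CountDefault extends Poly1 {
  implicit def default[T] = at[T](_ => 0)
}

// First argument: applied at every step of the structure.
object countStrings extends CountDefault {
  implicit val caseString = at[StringElement](_ => 1)
}

// Second argument: combines the results of two steps.
object addCounts extends Poly2 {
  implicit val caseInt = at[Int, Int](_ + _)
}

val complex = ComplexElement("root", List(
  StringElement("title", "Visitors"),
  ComplexElement("tags", List(StringElement("tag", "scala"), StringElement("tag", "shapeless"))),
  DoubleElement("rating", 4.5)
))

// Third argument: the root of the structure to traverse.
val stringCount: Int = everything(countStrings)(addCounts)(complex)
println(stringCount)
```

Each child in the list is visited once at its static type Element (where only the default case applies) and once at its concrete type, so every StringElement is counted exactly once, however deeply it is nested.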

So let's start with our polymorphic function for validating every step. This will be applied to every attribute of each class, including lists, maps and other classes, which will then get traversed as well (you can find out more about polymorphic functions and how they are implemented with Shapeless here).

So what is happening here, and why do we have two polymorphic functions? Let's start with the second one, validates, which handles the validation. Remember the lovely, simple type class we defined earlier? We are going to use it here: in this polymorphic function we define an implicit case that matches any attribute it finds that has an implicit ValidatorTypeClass in scope, and runs the validation (in our simple example, returning a Boolean for whether it passes or fails).

Now, there are also going to be other types in our structure that we essentially want to ignore. They might be simple attributes (Strings, etc.), or they might be Lists that we want to continue traversing but that, as a type in themselves, we can just pass over. For these we need a polymorphic function that is essentially a no-op and returns true. As the cases in a polymorphic function are implicits, we need to put this default case in the parent class: Scala's implicit resolution prefers an implicit defined in a subclass over an equally specific one inherited from a parent, so the default is resolved at a lower priority than our validating implicit.
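A sketch of those two polymorphic functions, assuming Scala 2.13 and Shapeless 2.3.x. The element model, the validation rules and the instance helper are illustrative assumptions; note that the ComplexElement instance now only checks the node itself (a hypothetical non-empty name rule), since walking the children is no longer its job.

```scala
//> using scala 2.13
//> using dep com.chuusai::shapeless:2.3.10
import shapeless._

sealed trait Element
final case class ComplexElement(name: String, value: List[Element]) extends Element
final case class StringElement(name: String, value: String) extends Element
final case class BooleanElement(name: String, value: Boolean) extends Element
final case class DoubleElement(name: String, value: Double) extends Element

trait ValidatorTypeClass[T] { def validate(t: T): Boolean }

object ValidatorTypeClass {
  def instance[T](f: T => Boolean): ValidatorTypeClass[T] =
    new ValidatorTypeClass[T] { def validate(t: T): Boolean = f(t) }

  // Hypothetical rules; no instance needs to know about child elements.
  implicit val complexValidator: ValidatorTypeClass[ComplexElement] =
    instance[ComplexElement](_.name.nonEmpty)
  implicit val stringValidator: ValidatorTypeClass[StringElement] =
    instance[StringElement](_.value.nonEmpty)
  implicit val doubleValidator: ValidatorTypeClass[DoubleElement] =
    instance[DoubleElement](d => !d.value.isNaN)
}

// The no-op case: anything without a validator (Strings, Lists, the Element
// parent type, ...) passes. Living in the parent trait gives it lower priority.
trait DefaultValidation extends Poly1 {
  implicit def default[T] = at[T](_ => true)
}

// Chosen whenever an implicit ValidatorTypeClass[T] exists for the node's type.
object validates extends DefaultValidation {
  implicit def validated[T](implicit v: ValidatorTypeClass[T]) = at[T](t => v.validate(t))
}

// Applying the polymorphic function directly to individual values.
val checks = List(
  validates(StringElement("title", "Visitors")),  // validator found: passes
  validates(StringElement("title", "")),          // validator found: fails
  validates(BooleanElement("published", false)),  // no validator: default case
  validates(List.empty[Element])                  // no validator: default case
)
println(checks)
```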

So everything is going to handle the generic traversal of our data structure, whatever that might look like, and this polymorphic function is going to return a Boolean indicating whether each element in the structure is OK. As mentioned, we now need to combine all these results from our structure.

To do that, we just define another polymorphic function, with arity 2, that defines how two results are combined. In the case of Booleans, that is very simple.

This combinator simply combines the Booleans with a logical AND, so as soon as one element fails, the overall answer is false.
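Putting the pieces together, a minimal end-to-end sketch, assuming Scala 2.13 and Shapeless 2.3.x with a hypothetical model and rules: the arity-2 combine function ANDs the per-node results, and everything(validates)(combine)(...) runs the whole traversal.

```scala
//> using scala 2.13
//> using dep com.chuusai::shapeless:2.3.10
import shapeless._

sealed trait Element
final case class ComplexElement(name: String, value: List[Element]) extends Element
final case class StringElement(name: String, value: String) extends Element
final case class BooleanElement(name: String, value: Boolean) extends Element
final case class DoubleElement(name: String, value: Double) extends Element

// Instances only check the node itself; everything does the recursion.
trait ValidatorTypeClass[T] { def validate(t: T): Boolean }
object ValidatorTypeClass {
  def instance[T](f: T => Boolean): ValidatorTypeClass[T] =
    new ValidatorTypeClass[T] { def validate(t: T): Boolean = f(t) }

  // Hypothetical rules.
  implicit val complexValidator: ValidatorTypeClass[ComplexElement] =
    instance[ComplexElement](_.name.nonEmpty)
  implicit val stringValidator: ValidatorTypeClass[StringElement] =
    instance[StringElement](_.value.nonEmpty)
  implicit val doubleValidator: ValidatorTypeClass[DoubleElement] =
    instance[DoubleElement](d => !d.value.isNaN)
}

// Fallback for every node without a validator: treat it as valid.
trait DefaultValidation extends Poly1 {
  implicit def default[T] = at[T](_ => true)
}

// Preferred over default (defined in the subtype) whenever a validator exists.
object validates extends DefaultValidation {
  implicit def validated[T](implicit v: ValidatorTypeClass[T]) = at[T](t => v.validate(t))
}

// Arity-2 combinator: a single false makes the overall result false.
object combine extends Poly2 {
  implicit val caseBoolean = at[Boolean, Boolean](_ && _)
}

val valid = ComplexElement("root", List(
  StringElement("title", "Visitors"),
  BooleanElement("published", true),
  ComplexElement("stats", List(DoubleElement("rating", 4.5)))
))

// The empty string two levels down should fail validation.
val invalid = ComplexElement("root", List(
  StringElement("title", "Visitors"),
  ComplexElement("stats", List(StringElement("unit", "")))
))

val validResult: Boolean = everything(validates)(combine)(valid)
val invalidResult: Boolean = everything(validates)(combine)(invalid)
println(s"valid: $validResult, invalid: $invalidResult")
```

No instance for Element is needed: the children of each ComplexElement reach validates at their concrete types via Shapeless' generic representation of the sealed trait.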

And that's it! Shapeless' everything handles the boilerplate, and with the addition of those minimal polymorphic functions we don't need to worry about traversing anything or pattern matching on parent types, so it ends up really quite nice. Nine extra lines of code, and our type class approach works after all!


Footnote 1: Further removing boilerplate

If you found yourself writing code like this a lot, you could simplify it further by changing our implicit ValidatorTypeClass to a broader VisitorTypeClass and providing a common set of combinators for the combine polymorphic function. Then all you would need to do each time is provide the specific type class implementation of VisitorTypeClass, and it would just work, as if by magic.

Footnote 2: A better validation

As mentioned, the validation example was purely illustrative, as it's a simple domain to understand, and there are other, better ways to perform simple validation (at construction time, other libraries, etc.). But if we wanted this to perform real validation rather than return Booleans, we could use something like Validated from Cats, which would let us accumulate meaningful failures throughout the traversal. This is really simple to drop in: besides having the type class and the default case return ValidatedNel instead of Boolean, all we would need to do is implement the combine polymorphic function for ValidatedNel.

Thankfully, Cats provides a Semigroup instance for ValidatedNel (and Validated has a combine method of its own), so all we need to do is call combine! (Note: that Semigroup also needs a Semigroup for whatever right-hand side type you choose to use in Validated, but that's trivial for most types.)
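A sketch of that variant, assuming Cats 2.x alongside Shapeless 2.3.x and Scala 2.13: the type class and the default case now return ValidatedNel[String, Unit], and combine simply delegates to Validated's own combine. The ValidationResult alias, the ok and failure helpers, the rules and the error messages are hypothetical.

```scala
//> using scala 2.13
//> using dep com.chuusai::shapeless:2.3.10
//> using dep org.typelevel::cats-core:2.9.0
import shapeless._
import cats.data.{Validated, ValidatedNel}
import cats.implicits._

sealed trait Element
final case class ComplexElement(name: String, value: List[Element]) extends Element
final case class StringElement(name: String, value: String) extends Element
final case class BooleanElement(name: String, value: Boolean) extends Element
final case class DoubleElement(name: String, value: Double) extends Element

// Hypothetical result type: Valid(()) or a non-empty list of error messages.
type ValidationResult = ValidatedNel[String, Unit]
val ok: ValidationResult = Validated.validNel[String, Unit](())
def failure(msg: String): ValidationResult = Validated.invalidNel[String, Unit](msg)

trait ValidatorTypeClass[T] { def validate(t: T): ValidationResult }
object ValidatorTypeClass {
  def instance[T](f: T => ValidationResult): ValidatorTypeClass[T] =
    new ValidatorTypeClass[T] { def validate(t: T): ValidationResult = f(t) }

  implicit val stringValidator: ValidatorTypeClass[StringElement] =
    instance[StringElement](s => if (s.value.nonEmpty) ok else failure(s"${s.name} is empty"))
  implicit val doubleValidator: ValidatorTypeClass[DoubleElement] =
    instance[DoubleElement](d => if (!d.value.isNaN) ok else failure(s"${d.name} is not a number"))
}

// The no-op default now returns a valid result instead of true.
trait DefaultValidation extends Poly1 {
  implicit def default[T] = at[T](_ => ok)
}
object validates extends DefaultValidation {
  implicit def validated[T](implicit v: ValidatorTypeClass[T]) = at[T](t => v.validate(t))
}

// Validated.combine uses the Semigroups for NonEmptyList[String] and Unit,
// so errors from both sides are accumulated rather than short-circuited.
object combine extends Poly2 {
  implicit val caseValidated = at[ValidationResult, ValidationResult](_ combine _)
}

val doc = ComplexElement("root", List(
  StringElement("title", ""),
  BooleanElement("published", true),
  ComplexElement("stats", List(DoubleElement("rating", Double.NaN)))
))

val result = everything(validates)(combine)(doc)
val okResult = everything(validates)(combine)(
  ComplexElement("root", List(StringElement("title", "Visitors"), DoubleElement("rating", 4.5)))
)
println(result)
```

Unlike the Boolean version, the invalid document reports both failures rather than a single false.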


