Re-thinking the Visitor pattern with Scala, Shapeless & polymorphic functions
In this article I will look at a relatively boilerplate-free way to traverse tree structures in Scala, using polymorphic functions along with Shapeless' everything function.

Over the course of my career, a problem I have had to face fairly repeatedly is dealing with nested, tree-like structures of arbitrary depth. From XML to directory structures to data models, nested trees or documents are a common and pretty useful way to model data.
Early in my career (the classic Java/J2EE/Spring days) I tackled them with the Visitor pattern from the Gang of Four, and I have probably implemented that pattern more than my fair share of times. Later, working in Groovy, I re-imagined the pattern a little to make it more idiomatic (dealing mostly with Maps and Lists). Now I am working in Scala, and once again the problem has arisen.
There are lots of things that Scala handles well. I generally like its type system, and everyone raves about pattern matching (which is undeniably useful), but it has always irked me a bit that, when dealing with subclasses, I have to match on every implementation to do anything with them. I always feel it's something I should be able to do with type classes, and I inevitably end up a little sad every time I remember I can't. Let me explain with a quick example. Let's imagine we are modeling a structure like XML. I will assume we all know XML, but essentially the format lets you define nested trees of elements: an element can be a complex type (like a directory or object) that holds further child elements, or a simple type (e.g. a string element that holds a string).
```scala
sealed trait Element
sealed trait SimpleElement[A] extends Element {
  def value: A
}
case class ComplexElement (value: List[Element]) extends Element
case class TextElement (value: String) extends SimpleElement[String]
case class NumberElement (value: Double) extends SimpleElement[Double]
case class BooleanElement (value: Boolean) extends SimpleElement[Boolean]
```
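For contrast, here is a minimal sketch of the pattern-matching approach over this structure. The validation rules (non-empty text, numbers that are not NaN) are placeholders chosen for illustration only. The point is that every implementation of Element needs its own case, so every new subtype means revisiting every function written this way.

```scala
sealed trait Element
sealed trait SimpleElement[A] extends Element {
  def value: A
}
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

object MatchValidation {
  // One case per implementation of Element: adding a new subtype means
  // revisiting every function written in this style.
  // The rules below are placeholders for illustration.
  def validate(e: Element): Boolean = e match {
    case ComplexElement(children) => children.forall(validate) // recurse into the children
    case TextElement(text)        => text.nonEmpty
    case NumberElement(number)    => !number.isNaN
    case BooleanElement(_)        => true
  }
}
```

For example, `MatchValidation.validate(ComplexElement(List(TextElement("a"), NumberElement(1.0))))` returns true, while an empty TextElement anywhere in the tree makes it return false.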
Now, when we have a ComplexElement and we want to process its children, a List[Element], ideally type classes would come to our rescue, like this:
```scala
sealed trait ValidatorTypeClass[A] {
  def validate(a: A): Boolean
}
object ValidatorTypeClass {
  def validateElement[A](a: A)(implicit v: ValidatorTypeClass[A]) = v.validate(a)
  implicit def stringElementValidator = new ValidatorTypeClass[String] {
    override def validate(a: String): Boolean = ??? //validation logic for strings
  }
  implicit def numberElementValidator = new ValidatorTypeClass[Double] {
    override def validate(a: Double): Boolean = ??? //validation logic for numbers
  }
  implicit def booleanElementValidator = new ValidatorTypeClass[Boolean] {
    override def validate(a: Boolean): Boolean = ??? //validation logic for booleans
  }
  implicit def complexElementValidator = new ValidatorTypeClass[ComplexElement] {
    // a.value is a List[Element], so this asks for a ValidatorTypeClass[Element]
    override def validate(a: ComplexElement): Boolean = a.value.forall(validateElement)
  }
}
import ValidatorTypeClass._
val complex = ComplexElement(
  value = List(
    TextElement(value = "first element"),
    TextElement(value = "second element")
  )
)
validateElement(complex)
```
However, if you try to compile the above code, you will get an error like this:
```
could not find implicit value for parameter v: ValidatorTypeClass[Element]
override def validate(a: ComplexElement): Boolean = a.value.forall(validateElement)
```
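The compiler is right: `a.value` is statically a `List[Element]`, so `validateElement` asks for a `ValidatorTypeClass[Element]`, and the instances for the concrete types are never consulted. The only hand-written fix is an instance for `Element` itself that recovers the concrete type. Below is a sketch of that fix using the same type class, with placeholder validation rules chosen for illustration. It compiles, but it is the exhaustive match all over again, which is exactly the boilerplate we wanted to avoid.

```scala
trait ValidatorTypeClass[A] {
  def validate(a: A): Boolean
}

sealed trait Element
sealed trait SimpleElement[A] extends Element { def value: A }
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

object ValidatorTypeClass {
  def validateElement[A](a: A)(implicit v: ValidatorTypeClass[A]): Boolean = v.validate(a)

  // Placeholder rules, for illustration only.
  implicit val stringValidator: ValidatorTypeClass[String] =
    new ValidatorTypeClass[String] { def validate(a: String): Boolean = a.nonEmpty }
  implicit val numberValidator: ValidatorTypeClass[Double] =
    new ValidatorTypeClass[Double] { def validate(a: Double): Boolean = !a.isNaN }
  implicit val booleanValidator: ValidatorTypeClass[Boolean] =
    new ValidatorTypeClass[Boolean] { def validate(a: Boolean): Boolean = true }

  // The instance the compiler asked for. The static type is only Element,
  // so the concrete type has to be recovered with a match.
  implicit val elementValidator: ValidatorTypeClass[Element] =
    new ValidatorTypeClass[Element] {
      def validate(a: Element): Boolean = a match {
        case c: ComplexElement => validateElement(c)
        case TextElement(s)    => validateElement(s)
        case NumberElement(d)  => validateElement(d)
        case BooleanElement(b) => validateElement(b)
      }
    }

  implicit val complexValidator: ValidatorTypeClass[ComplexElement] =
    new ValidatorTypeClass[ComplexElement] {
      // Compiles now: elementValidator supplies the ValidatorTypeClass[Element]
      def validate(a: ComplexElement): Boolean = a.value.forall(e => validateElement(e))
    }
}
```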
So I wanted to find a better way, and, having written about Shapeless a couple of times before, once again...
Enter Shapeless
The good news is that Shapeless has some tools that can help here. The bad news is that there isn't really any documentation for some of these features beyond the source code and unit tests, and some of them don't seem to be mentioned anywhere at all! I had previously used a Shapeless function called everywhere. Even that function isn't really called out in the docs: I stumbled upon it in an article about what was new in Shapeless 2.0, where it appeared in an example without any mention or explanation. everywhere allows in-place editing of tree-like structures (or any structures, really), and is based on the ideas laid out in the Scrap Your Boilerplate (SYB) paper, on which large parts of the Shapeless library were based.

As well as everywhere, Shapeless provides a function called everything, which also comes from the SYB paper. Instead of editing, it lets you simply traverse, or visit, generic data structures. It's conceptually pretty simple, but finding any mention of it in docs or footnotes was hard (I found it by reading the source code), so let's go through it.
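everything takes three arguments, each in its own parameter list: a polymorphic function (a Poly1) that is applied to every node in the structure, a polymorphic function of arity 2 (a Poly2) that combines two results into one, and the structure to traverse: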
```scala
everything(validates)(combine)(complex)
```
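To make the three arguments concrete, here is a deliberately untyped sketch of the same idea in plain Scala, without Shapeless. The SYB paper defines the function roughly as `everything k f x = foldl k (f x) (gmapQ (everything k f) x)`: apply the query `f` to a node, recurse into its immediate children, and fold all the results together with `k`. Note that SYB takes the combiner first, whereas Shapeless takes the query first. `everythingSketch` is a hypothetical helper, not how Shapeless implements it. It inspects values of type `Any` at runtime (case-class fields via `productIterator`, collection elements via `iterator`), so it cannot pick a type class instance per node. Selecting an instance per node is what the polymorphic functions in the rest of this article provide.

```scala
sealed trait Element
sealed trait SimpleElement[A] extends Element { def value: A }
case class ComplexElement(value: List[Element]) extends Element
case class TextElement(value: String) extends SimpleElement[String]
case class NumberElement(value: Double) extends SimpleElement[Double]
case class BooleanElement(value: Boolean) extends SimpleElement[Boolean]

object EverythingSketch {
  // Hypothetical helper: a runtime (untyped) analogue of SYB's
  //   everything k f x = foldl k (f x) (gmapQ (everything k f) x)
  // `f` is the query applied to every node, `k` combines two results.
  def everythingSketch[R](f: Any => R)(k: (R, R) => R)(t: Any): R = {
    val children: Iterator[Any] = t match {
      case it: Iterable[_] => it.iterator       // elements of lists, sets, maps
      case p: Product      => p.productIterator // fields of case classes
      case _               => Iterator.empty    // leaves such as String or Double
    }
    children.foldLeft(f(t))((acc, child) => k(acc, everythingSketch(f)(k)(child)))
  }

  // Query: only strings are checked; every other node counts as valid.
  val nonEmptyStrings: Any => Boolean = {
    case s: String => s.nonEmpty
    case _         => true
  }

  val complex: ComplexElement = ComplexElement(
    List(TextElement("first element"), TextElement("second element"))
  )

  val allValid: Boolean = everythingSketch(nonEmptyStrings)(_ && _)(complex) // true
  // Nodes visited: complex, its List, two TextElements and their two Strings.
  val nodeCount: Int = everythingSketch[Int](_ => 1)(_ + _)(complex) // 6
}
```

Swapping the query and the combiner changes what the traversal computes (validity, a node count, and so on) without touching the traversal itself, which is the appeal of everything.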
So let's start with our polymorphic function for validating each node. It will be applied to every attribute of each class, including lists, maps and other classes, which then get traversed as well (you can find out more about polymorphic functions and how they are implemented with Shapeless here):
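```scala
import shapeless._
// Fallback for any type without a validator; declared in the parent trait so that
// it has lower implicit priority than caseValidated
sealed trait DefaultValidation extends Poly1 {
  implicit def default[T] = at[T](x => true)
}
// Applied at each node: uses the ValidatorTypeClass instance when one exists for that type
object validates extends DefaultValidation {
  implicit def caseValidated[A](implicit v: ValidatorTypeClass[A]) = at[A](x => v.validate(x))
}
```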
Now, there are also going to be other types in our structure that we essentially want to ignore. They might be simple attributes we have no validator for, or they might be Lists that we want to keep traversing but have no need to validate as a type in themselves. For these we need a case that is essentially a no-op and returns true. Because the cases of a polymorphic function are implicits, the default case has to live in the parent trait so that it is resolved at a lower priority than our validating implicit.
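The priority trick described above is plain Scala implicit resolution rather than anything specific to Shapeless. Below is a minimal, Shapeless-free sketch with hypothetical names (`Validator`, `Check`, `LowPriorityCheck`). Both implicit candidates are generic, so the tie is broken by the rule that a member defined in a subclass (here the `Check` companion) beats one inherited from its parent trait. The fallback is therefore only used when no `Validator` exists for the type. Putting both cases in the same object would typically produce an ambiguous-implicits error instead.

```scala
// Hypothetical per-type rule, standing in for ValidatorTypeClass.
trait Validator[A] {
  def validate(a: A): Boolean
}
object Validator {
  implicit val nonEmptyString: Validator[String] = new Validator[String] {
    def validate(a: String): Boolean = a.nonEmpty
  }
}

// Hypothetical stand-in for a Poly1 case: "what to do with a value of type A".
trait Check[A] {
  def apply(a: A): Boolean
}

trait LowPriorityCheck {
  // Fallback for any type: treat it as valid (the no-op case).
  implicit def default[A]: Check[A] = new Check[A] {
    def apply(a: A): Boolean = true
  }
}

object Check extends LowPriorityCheck {
  // Preferred whenever a Validator[A] can be found, because it is defined
  // in a subclass of the trait that defines `default`.
  implicit def validated[A](implicit v: Validator[A]): Check[A] = new Check[A] {
    def apply(a: A): Boolean = v.validate(a)
  }

  def check[A](a: A)(implicit c: Check[A]): Boolean = c(a)
}

object CheckDemo extends App {
  println(Check.check(""))   // false: Validator[String] exists, so `validated` is used
  println(Check.check("ok")) // true
  println(Check.check(42))   // true: no Validator[Int], so `default` is used
}
```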
So, everything is going to handle the generic traversal of our data structure, whatever that might look like, and this polymorphic function is going to return a Boolean indicating whether each element in the structure is OK. Now, as mentioned, we need to combine all of these results from across our structure.
To do that, we just define another polymorphic function, this time with arity 2, that defines how two results are combined. For Booleans this is very simple:
```scala
object combine extends Poly2 {
  implicit def caseValidation = at[Boolean, Boolean] (_ && _)
}
```
Putting it all together, here is the complete example:

```scala
import shapeless._

// The element model from the start of the article
sealed trait Element
sealed trait SimpleElement[A] extends Element {
  def value: A
}
case class ComplexElement (value: List[Element]) extends Element
case class TextElement (value: String) extends SimpleElement[String]
case class NumberElement (value: Double) extends SimpleElement[Double]
case class BooleanElement (value: Boolean) extends SimpleElement[Boolean]

// Per-type validation rules (these stubs accept everything)
sealed trait ValidatorTypeClass[A] {
  def validate(a: A): Boolean
}
object ValidatorTypeClass {
  def validateElement[A](a: A)(implicit v: ValidatorTypeClass[A]) = v.validate(a)
  implicit def stringElementValidator = new ValidatorTypeClass[String] {
    override def validate(a: String): Boolean = true //validation logic for strings
  }
  implicit def numberElementValidator = new ValidatorTypeClass[Double] {
    override def validate(a: Double): Boolean = true //validation logic for numbers
  }
  implicit def booleanElementValidator = new ValidatorTypeClass[Boolean] {
    override def validate(a: Boolean): Boolean = true //validation logic for booleans
  }
  implicit def complexElementValidator = new ValidatorTypeClass[ComplexElement] {
    override def validate(a: ComplexElement): Boolean = true
  }
}

// Low-priority fallback: any node without a validator counts as valid
sealed trait DefaultValidation extends Poly1 {
  implicit def default[T] = at[T](x => true)
}
// Applied at every node, using a ValidatorTypeClass instance when one exists
object validates extends DefaultValidation {
  implicit def caseValidated[A](implicit v: ValidatorTypeClass[A]) = at[A](x => v.validate(x))
}
// Folds the per-node results together
object combine extends Poly2 {
  implicit def caseValidation = at[Boolean, Boolean] (_ && _)
}

val complex = ComplexElement(
  value = List(
    TextElement(value = "first element"),
    TextElement(value = "second element")
  )
)

everything(validates)(combine)(complex)
```
Footnote 1: Further removing boilerplate
If you found yourself writing code like this a lot, you could simplify it further by generalizing our implicit ValidatorTypeClass into a broader VisitorTypeClass and providing a common set of combinators for the combine polymorphic function. Then all you would need to do each time is provide the specific type class implementation of VisitorTypeClass, and it would just work, as if by magic.

Footnote 2: A better validation
As mentioned, the validation example was purely illustrative because it's a simple domain to understand, and there are other, better ways to perform simple validation (at construction time, other libraries, etc.). But if we wanted this to perform real validation rather than return Booleans, we could use something like Validated from Cats, which would let us accumulate meaningful failures throughout the traversal. This is really simple to drop in. validates would return a ValidatedNel[String, Boolean] for each node (including from its default case) instead of a Boolean, and all we would need to do is implement the combine polymorphic function for ValidatedNel:
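```scala
import cats.data.ValidatedNel
import cats.implicits._
import shapeless._
// Cats has no Semigroup[Boolean], so a.combine(b) won't compile; mapN ANDs Valids and accumulates Invalids
object combineFieldValues extends Poly2 {
  implicit def caseValidation =
    at[ValidatedNel[String, Boolean], ValidatedNel[String, Boolean]]((a, b) => (a, b).mapN(_ && _))
}
```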
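To see what ValidatedNel buys over a plain Boolean, here is a small sketch using only Cats (cats-core, no Shapeless). The error messages and the `combineResults` helper are hypothetical placeholders. Combining two invalid results keeps both messages rather than stopping at the first, which `&&` on Booleans cannot do.

```scala
import cats.data.ValidatedNel
import cats.implicits._

object ValidatedAccumulation {
  // Hypothetical helper playing the role of the combine case:
  // Valid values are ANDed, Invalid error lists are concatenated in order.
  def combineResults(a: ValidatedNel[String, Boolean],
                     b: ValidatedNel[String, Boolean]): ValidatedNel[String, Boolean] =
    (a, b).mapN(_ && _)

  // Placeholder per-node results
  val ok: ValidatedNel[String, Boolean]        = true.validNel[String]
  val emptyText: ValidatedNel[String, Boolean] = "text element is empty".invalidNel[Boolean]
  val badNumber: ValidatedNel[String, Boolean] = "number element is NaN".invalidNel[Boolean]

  val bothOk: ValidatedNel[String, Boolean]  = combineResults(ok, ok)               // Valid(true)
  val bothBad: ValidatedNel[String, Boolean] = combineResults(emptyText, badNumber) // both messages, in order
}
```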