In my previous post I looked at a simple way to eliminate conditional expressions by encoding them into the Scala type system. In this follow-up post I want to look at taking this further by using the Scala type system to encode state rules. The aim, as always, is to create code that fails to compile rather than code that fails at runtime. By doing this we also reduce the number of unit tests required.
The Problem
Consider a simple shopping workflow. I collect information about a Basket, the Customer and their Payment Method. Once I have all this in place, I can create a processOrder function that completes the workflow. As a naive starting point, let's encode this as a simple domain model holding optional values:
```scala
case class CheckoutWorkflow(basket: Option[Basket],
                            customer: Option[Customer],
                            paymentMethod: Option[PaymentMethod])

object CheckoutWorkflow {
  def empty = new CheckoutWorkflow(None, None, None)
}
```
Then we need some functions that populate the workflow:
```scala
def processBasket(workflow: CheckoutWorkflow): CheckoutWorkflow = {
  if (workflow.basket.isDefined)
    throw new IllegalStateException("Basket workflow step already processed")
  // Do some processing...
  val basket: Basket = ??? // placeholder for the real processing
  workflow.copy(basket = Some(basket))
}

def processCustomer(workflow: CheckoutWorkflow): CheckoutWorkflow = {
  if (workflow.customer.isDefined)
    throw new IllegalStateException("Customer workflow step already processed")
  // Do some processing...
  val customer: Customer = ??? // placeholder for the real processing
  workflow.copy(customer = Some(customer))
}

def processPaymentMethod(workflow: CheckoutWorkflow): CheckoutWorkflow = {
  if (workflow.paymentMethod.isDefined)
    throw new IllegalStateException("Payment Method workflow step already processed")
  // Do some processing...
  val paymentMethod: PaymentMethod = ??? // placeholder for the real processing
  workflow.copy(paymentMethod = Some(paymentMethod))
}
```
Note that each of the above functions has a guard condition to stop them being called multiple times for the same workflow. Each of these guard conditions would require a separate unit test to ensure that it works and to avoid regressions should it be accidentally removed in the future.
Finally, we need the function that processes the order. Given our domain model above, this function needs to contain a conditional check to ensure that all the workflow requirements are satisfied before processing of the order can commence. Something like:
```scala
def processOrder(workflow: CheckoutWorkflow) = workflow match {
  case CheckoutWorkflow(Some(basket), Some(customer), Some(paymentMethod)) =>
    // Do the order processing
  case _ =>
    throw new IllegalStateException("Workflow requirements not satisfied")
}
```
None of the above is ideal, as there are a number of places where the code can fail at runtime. The conditionals pollute our code with non-business logic. We also have to write good unit tests to ensure that all the conditionals work correctly and have not been accidentally removed. Even then, a client may call our code without having met the requirements encoded in the conditionals, and they will receive a runtime error. Let's hope they unit test as thoroughly as we do! Surely we can do better than this?
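To make those runtime failure modes concrete, here is a minimal, self-contained sketch of the naive model. It assumes simplified stand-in definitions for Basket, Customer and PaymentMethod (their contents are never spelled out), passes each step's result in as a parameter in place of the elided processing, and adds a hypothetical NaiveDemo object. Every call in it compiles; the guards and the final match only reject the bad calls at runtime, which is exactly what the unit tests have to pin down.

```scala
import scala.util.Try

// Hypothetical minimal domain types, just enough to compile.
final case class Basket(items: List[String])
final case class Customer(id: String)
final case class PaymentMethod(paymentType: String)

final case class CheckoutWorkflow(
    basket: Option[Basket],
    customer: Option[Customer],
    paymentMethod: Option[PaymentMethod]
)

object CheckoutWorkflow {
  def empty: CheckoutWorkflow = CheckoutWorkflow(None, None, None)
}

object NaiveCheckout {
  // Assumption: each step is handed the value it would otherwise compute.
  def processBasket(workflow: CheckoutWorkflow, basket: Basket): CheckoutWorkflow = {
    // Runtime guard: nothing stops a caller from invoking this twice.
    if (workflow.basket.isDefined)
      throw new IllegalStateException("Basket workflow step already processed")
    workflow.copy(basket = Some(basket))
  }

  def processCustomer(workflow: CheckoutWorkflow, customer: Customer): CheckoutWorkflow = {
    if (workflow.customer.isDefined)
      throw new IllegalStateException("Customer workflow step already processed")
    workflow.copy(customer = Some(customer))
  }

  def processPaymentMethod(workflow: CheckoutWorkflow, paymentMethod: PaymentMethod): CheckoutWorkflow = {
    if (workflow.paymentMethod.isDefined)
      throw new IllegalStateException("Payment Method workflow step already processed")
    workflow.copy(paymentMethod = Some(paymentMethod))
  }

  // Runtime check that every requirement has been collected.
  def processOrder(workflow: CheckoutWorkflow): String = workflow match {
    case CheckoutWorkflow(Some(basket), Some(customer), Some(paymentMethod)) =>
      s"order for ${customer.id}: ${basket.items.size} item(s), paid by ${paymentMethod.paymentType}"
    case _ =>
      throw new IllegalStateException("Workflow requirements not satisfied")
  }
}

object NaiveDemo {
  import NaiveCheckout._

  val withBasket: CheckoutWorkflow = processBasket(CheckoutWorkflow.empty, Basket(List("book")))

  // The happy path: every step exactly once.
  val complete: CheckoutWorkflow =
    processPaymentMethod(processCustomer(withBasket, Customer("c-42")), PaymentMethod("card"))

  // Both of these compile, yet each one fails at runtime.
  val basketTwice: Try[CheckoutWorkflow] = Try(processBasket(withBasket, Basket(List("pen"))))
  val orderTooEarly: Try[String] = Try(processOrder(withBasket))
}
```

In a real test suite, each of those failing calls would become its own test case: one per guard, plus one for processOrder.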
A Less Than Ideal Solution
Well, we could use the approach outlined in my previous post and use domain model extensions:
```scala
case class WorkflowWithBasket(basket: Basket)
case class WorkflowWithBasketAndCustomer(basket: Basket, customer: Customer)
case class WorkflowWithAllRequirements(basket: Basket, customer: Customer, paymentMethod: PaymentMethod)

def processOrder(workflow: WorkflowWithAllRequirements) = {
  // Do the order processing
}
```
While this does allow us to remove all the conditionals and their associated tests, it unfortunately also reduces the flexibility of our model quite significantly: the order in which the workflow steps must be processed is now encoded into the domain model. Not ideal. We'd like to keep the flexibility of the first solution, but in a type-safe way. Is there a way to encode the requirements into the type system?
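The rigidity is easier to see with the transitions written out. In this sketch the withCustomer and withPaymentMethod helpers, the FixedOrderCheckout object and the simplified domain types are hypothetical additions; the three workflow classes follow the ones above. Each class can only be built from its predecessor, so the order basket, customer, payment method is baked into the types.

```scala
// Hypothetical minimal domain types.
final case class Basket(items: List[String])
final case class Customer(id: String)
final case class PaymentMethod(paymentType: String)

final case class WorkflowWithBasket(basket: Basket) {
  // Hypothetical helper: the only way forward is to add the customer next.
  def withCustomer(customer: Customer): WorkflowWithBasketAndCustomer =
    WorkflowWithBasketAndCustomer(basket, customer)
}

final case class WorkflowWithBasketAndCustomer(basket: Basket, customer: Customer) {
  // Hypothetical helper: the payment method can only come last.
  def withPaymentMethod(paymentMethod: PaymentMethod): WorkflowWithAllRequirements =
    WorkflowWithAllRequirements(basket, customer, paymentMethod)
}

final case class WorkflowWithAllRequirements(basket: Basket, customer: Customer, paymentMethod: PaymentMethod)

object FixedOrderCheckout {
  // No runtime checks needed: the parameter type guarantees all three values exist.
  def processOrder(workflow: WorkflowWithAllRequirements): String =
    s"order for ${workflow.customer.id}, paid by ${workflow.paymentMethod.paymentType}"

  val order: String = processOrder(
    WorkflowWithBasket(Basket(List("book")))
      .withCustomer(Customer("c-42"))
      .withPaymentMethod(PaymentMethod("card"))
  )

  // There is no type for "customer known, basket not yet", so collecting the
  // customer first would need yet another set of classes for that ordering.
}
```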
Leveraging The Type System
First, let's consider what we want to achieve. Our aim is to encode unsatisfied and satisfied workflow requirements, and to only allow methods to be called when the correct combination is set. So, let's first encode the concept of requirements:
```scala
trait BasketRequirement
case object UnsatisfiedBasketRequirement extends BasketRequirement

trait CustomerRequirement
case object UnsatisfiedCustomerRequirement extends CustomerRequirement

trait PaymentMethodRequirement
case object UnsatisfiedPaymentMethodRequirements extends PaymentMethodRequirement
```
Here we have defined a requirement trait for each of the different workflow stages, along with a case object representing the unsatisfied state of each requirement. Next we need to indicate the satisfied states, which are our actual domain object classes:
```scala
case class Basket(items: List[LineItem]) extends BasketRequirement
case class Customer(id: String) extends CustomerRequirement
case class PaymentMethod(paymentType: PaymentType) extends PaymentMethodRequirement
```
The next job is to make sure that our workflow object can represent these requirements and be strongly typed on either the satisfied or the unsatisfied state of each one. We do this by giving the workflow one type parameter per requirement, each with an upper bound of the corresponding requirement trait. This also eliminates the need for the Option types. We also define an 'unsatisfied' instance as the starting point for our workflow:
```scala
case class CheckoutWorkflow[B <: BasketRequirement, C <: CustomerRequirement, PM <: PaymentMethodRequirement](
    basket: B,
    customer: C,
    paymentMethod: PM
)

object CheckoutWorkflow {
  val unsatisfied = CheckoutWorkflow(UnsatisfiedBasketRequirement,
                                     UnsatisfiedCustomerRequirement,
                                     UnsatisfiedPaymentMethodRequirements)
}
```
Now we need the functions that actually process each individual workflow stage. Note how each one declares bounded type parameters for the stages it doesn't care about. For the stage that it actually manipulates, however, it requires the Unsatisfied type as input and returns the Satisfied type. Thus you can no longer call any of these methods with a workflow that already has that stage satisfied: the compiler won't allow it:
```scala
def processBasket[C <: CustomerRequirement, PM <: PaymentMethodRequirement](
    workflow: CheckoutWorkflow[UnsatisfiedBasketRequirement.type, C, PM]
): CheckoutWorkflow[Basket, C, PM] = {
  // Do some processing...
  val basket: Basket = ??? // placeholder for the real processing
  workflow.copy(basket = basket)
}

def processCustomer[B <: BasketRequirement, PM <: PaymentMethodRequirement](
    workflow: CheckoutWorkflow[B, UnsatisfiedCustomerRequirement.type, PM]
): CheckoutWorkflow[B, Customer, PM] = {
  // Do some processing...
  val customer: Customer = ??? // placeholder for the real processing
  workflow.copy(customer = customer)
}

def processPaymentMethod[B <: BasketRequirement, C <: CustomerRequirement](
    workflow: CheckoutWorkflow[B, C, UnsatisfiedPaymentMethodRequirements.type]
): CheckoutWorkflow[B, C, PaymentMethod] = {
  // Do some processing...
  val paymentMethod: PaymentMethod = ??? // placeholder for the real processing
  workflow.copy(paymentMethod = paymentMethod)
}
```
Finally, our processOrder function becomes super simple. It can no longer be called unless every requirement is satisfied, that is, unless none of the workflow's type parameters is an Unsatisfied type:
```scala
def processOrder(workflow: CheckoutWorkflow[Basket, Customer, PaymentMethod]) = {
  // Process the order
}
```
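Putting the pieces together, here is a self-contained sketch of the typed workflow that can be compiled and run. It follows the definitions above with a few assumptions: LineItem, PaymentType and Card are hypothetical stand-ins, each step takes the value it would otherwise compute as an extra parameter in place of the elided processing, and processOrder returns a summary string so the result can be checked. Both call orders type-check and reach processOrder; the two commented-out lines at the end are the kinds of call the compiler rejects.

```scala
// Hypothetical stand-ins for types that are never defined above.
final case class LineItem(sku: String, quantity: Int)
sealed trait PaymentType
case object Card extends PaymentType

// Requirement traits, each with an "unsatisfied" case object.
trait BasketRequirement
case object UnsatisfiedBasketRequirement extends BasketRequirement
trait CustomerRequirement
case object UnsatisfiedCustomerRequirement extends CustomerRequirement
trait PaymentMethodRequirement
case object UnsatisfiedPaymentMethodRequirements extends PaymentMethodRequirement

// The satisfied states are the domain classes themselves.
final case class Basket(items: List[LineItem]) extends BasketRequirement
final case class Customer(id: String) extends CustomerRequirement
final case class PaymentMethod(paymentType: PaymentType) extends PaymentMethodRequirement

final case class CheckoutWorkflow[B <: BasketRequirement, C <: CustomerRequirement, PM <: PaymentMethodRequirement](
    basket: B,
    customer: C,
    paymentMethod: PM
)

object CheckoutWorkflow {
  val unsatisfied = CheckoutWorkflow(
    UnsatisfiedBasketRequirement,
    UnsatisfiedCustomerRequirement,
    UnsatisfiedPaymentMethodRequirements
  )
}

object TypedCheckout {
  // Each step only accepts a workflow whose own stage is still unsatisfied.
  // Assumption: the step's result is passed in rather than computed here.
  def processBasket[C <: CustomerRequirement, PM <: PaymentMethodRequirement](
      workflow: CheckoutWorkflow[UnsatisfiedBasketRequirement.type, C, PM],
      basket: Basket
  ): CheckoutWorkflow[Basket, C, PM] = workflow.copy(basket = basket)

  def processCustomer[B <: BasketRequirement, PM <: PaymentMethodRequirement](
      workflow: CheckoutWorkflow[B, UnsatisfiedCustomerRequirement.type, PM],
      customer: Customer
  ): CheckoutWorkflow[B, Customer, PM] = workflow.copy(customer = customer)

  def processPaymentMethod[B <: BasketRequirement, C <: CustomerRequirement](
      workflow: CheckoutWorkflow[B, C, UnsatisfiedPaymentMethodRequirements.type],
      paymentMethod: PaymentMethod
  ): CheckoutWorkflow[B, C, PaymentMethod] = workflow.copy(paymentMethod = paymentMethod)

  // Only a fully satisfied workflow is accepted; no runtime check is needed.
  def processOrder(workflow: CheckoutWorkflow[Basket, Customer, PaymentMethod]): String =
    s"order for ${workflow.customer.id}: ${workflow.basket.items.size} line item(s)"
}

object TypedDemo {
  import TypedCheckout._

  // Compiles without a cast: this is the type inferred for `unsatisfied`.
  val start: CheckoutWorkflow[
    UnsatisfiedBasketRequirement.type,
    UnsatisfiedCustomerRequirement.type,
    UnsatisfiedPaymentMethodRequirements.type
  ] = CheckoutWorkflow.unsatisfied

  val basket = Basket(List(LineItem("book", 1), LineItem("pen", 3)))

  // Basket, then customer, then payment method...
  val basketFirst: String = processOrder(
    processPaymentMethod(processCustomer(processBasket(start, basket), Customer("c-42")), PaymentMethod(Card))
  )

  // ...or payment method, then customer, then basket: both orders type-check.
  val paymentFirst: String = processOrder(
    processBasket(processCustomer(processPaymentMethod(start, PaymentMethod(Card)), Customer("c-42")), basket)
  )

  // Neither of these compiles, so they are left as comments:
  // processOrder(start)                                  // requirements missing
  // processBasket(processBasket(start, basket), basket)  // basket processed twice
}
```

The explicit type on start is only there to show what the compiler infers for CheckoutWorkflow.unsatisfied: each case object is given its own singleton type, such as UnsatisfiedBasketRequirement.type, and that is what the processing functions' parameter types match against.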
One observation made by someone (a Java dev) who looked at my initial draft of the code was that the final solution looks more complex because of all the type bounds, and that it actually needs more classes, since the requirement traits and the Unsatisfied state objects have to be defined. However, don't forget that this solution eliminates at least four conditional blocks (the three guards and the final match) along with their associated unit tests, simplifies others, and may also reduce the number of tests that clients need to write for their own code. Also, a workflow step can no longer fail at runtime because it was called twice or too early: that mistake is now a compile error. If the code compiles, we can have much higher confidence that it will work. Well worth a tiny bit more type complexity in my book.