Codementor Events

A simple quaternion library or a lesson in Operator Overloading

Published Jun 03, 2023

Just for fun, and to try out Rust’s operator overloading capabilities, I decided to write a quaternion library. If you don’t know what quaternions are, they are basically complex numbers on steroids. If you want to know more, there is an interesting Wikipedia article on them.

Instead of one imaginary unit, quaternions have three, namely i, j and k, so a quaternion looks like r + ai + bj + ck.

The main rule to follow is: $i^2 = j^2 = k^2 = ijk = -1$
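This single rule is enough to derive the multiplication table of the units. For example, multiplying $ijk = -1$ by $k$ on the right, or by $i$ on the left, gives:

$$
\begin{aligned}
ijk = -1 \;&\Rightarrow\; ij\,k^2 = -k \;\Rightarrow\; -ij = -k \;\Rightarrow\; ij = k \\
ijk = -1 \;&\Rightarrow\; i^2 jk = -i \;\Rightarrow\; -jk = -i \;\Rightarrow\; jk = i \\
ji &= j(jk) = j^2 k = -k
\end{aligned}
$$

So $ij = k$ but $ji = -k$: unlike complex numbers, quaternion multiplication is not commutative, which matters later for multiplication and division.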

Open your terminal in an empty directory and type:

cargo new quaternions
cd quaternions

Now open main.rs in the src directory in your favourite IDE and start with these imports:

use std::fmt::{Debug, Display};
use std::ops::{Add, Mul, Div, Sub, Neg};
use std::cmp::{PartialEq, PartialOrd};

The Quaternion struct looks as follows:

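// Clone and Copy let a quaternion be reused after it has been passed by value
// to +, -, * or /, which main() relies on
#[derive(Clone, Copy)]
struct Quaternion<T> where T:Add<Output = T> 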
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T> 
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug 
{
    r: T,
    i: T,
    j: T,
    k: T,
}

As you can see, Quaternion is a simple struct with a long list of type constraints, which we will need later for the operator overloading.

For example, Add<Output = T> means that type T has to implement the Add trait, and that adding two values of type T produces another T. The same goes for most of the other constraints.

The PartialEq and PartialOrd constraints require T itself to support == and <. We need them later to implement those same traits for Quaternion, so that we can use == or < on whole quaternions.

We will also implement the Debug trait for our Quaternions so we can print them out:

impl<T> Debug for Quaternion<T> where T:Add<Output = T> 
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T> 
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug
                           + Display {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} + {}i + {}j + {}k", self.r, self.i, self.j, self.k)
    }
}

This is quite a rough but workable implementation. It could be refined, for example, by printing a minus sign instead of a plus sign when a coefficient is negative, or by leaving out a term when its coefficient is zero. I leave that as an exercise to the reader.

Note the extra Display constraint: write! formats each coefficient with {}, which requires T to implement Display.
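One possible take on that exercise, sketched on plain f64 coefficients to keep it short (an assumption, the article's impl is generic): a hypothetical format_quaternion helper that skips zero terms and prints "a - b" instead of "a + -b". A generic version could compare against T::default() the way is_null() does.

```rust
// Hypothetical helper (not part of the article's code): formats
// r + xi + yj + zk for f64 coefficients, omitting zero terms and
// printing "a - b" instead of "a + -b".
fn format_quaternion(r: f64, i: f64, j: f64, k: f64) -> String {
    let mut out = String::new();
    for (value, unit) in [(r, ""), (i, "i"), (j, "j"), (k, "k")] {
        if value == 0.0 {
            continue; // skip zero terms (this also matches -0.0)
        }
        if out.is_empty() {
            // the first printed term keeps its own sign
            out.push_str(&format!("{}{}", value, unit));
        } else if value < 0.0 {
            out.push_str(&format!(" - {}{}", -value, unit));
        } else {
            out.push_str(&format!(" + {}{}", value, unit));
        }
    }
    if out.is_empty() {
        out.push('0'); // every coefficient was zero
    }
    out
}

fn main() {
    println!("{}", format_quaternion(1.0, -2.0, 0.0, 4.0));
}
```

Inside the Debug impl, fmt() could then simply write!(f, "{}", ...) the returned string.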

The Quaternion implementation is quite straightforward:

impl<T> Quaternion<T> where T:Add<Output = T> 
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T> 
                           + Neg<Output = T>
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug
                           + Default
{
    fn new<Q>(r: Q, i: Q, j: Q, k: Q) -> Quaternion<Q> where Q:Add<Output = Q> 
                                                             + Mul<Output = Q> 
                                                             + Div<Output = Q> 
                                                             + Sub<Output = Q> 
                                                             + PartialEq 
                                                             + PartialOrd 
                                                             + Copy 
                                                             + Debug
                                                             + Default 
    {
        Quaternion { r, i, j, k }
    }

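    // Negate the coefficients of i, j and k
    fn conjugate(&self) -> Quaternion<T> {
        Quaternion {
            r: self.r,
            i: -self.i,
            j: -self.j,
            k: -self.k,
        }
    }

    // Returns the squared magnitude r² + i² + j² + k² (no square root)
    fn magnitude(&self) -> T {
        self.r * self.r + self.i * self.i + self.j * self.j + self.k * self.k
    }

    // True when every coefficient equals T::default(), i.e. zero for numeric types
    fn is_null(&self) -> bool {
        self.r == T::default() && self.i == T::default() && self.j == T::default() && self.k == T::default()
    }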
    
}

Some explanation:

  • A constructor called new(). Notice that new() is itself a generic method, with its own type parameter Q, so it also needs its own copy of the list of type constraints.
  • A conjugate() method, which negates the coefficients of i, j and k to produce the conjugate. We need it later to divide two quaternions.
  • A magnitude() method, which actually returns the square of the magnitude, $r^2 + i^2 + j^2 + k^2$, rather than the magnitude itself, because that is all we need in this example.
  • An is_null() method to check whether the quaternion is 0, which is also important for division.

Notice that I added a Default constraint, so we can easily test for a zero or zero-like value in the is_null() method.
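To see why conjugate() and magnitude() are all that division needs, here is a small sketch on a concrete f64-only type called Quat (an assumption, standing in for the generic Quaternion<T>). It checks that multiplying a quaternion by its conjugate leaves only a real part, equal to the squared magnitude. Its constructor also takes the field type directly and returns Self; in the generic version, new(r: T, i: T, j: T, k: T) -> Self would likewise avoid the separate Q parameter and its constraint list.

```rust
// f64-only stand-in for the article's Quaternion<T> (an assumption made
// to keep the sketch short and self-contained).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Quat {
    r: f64,
    i: f64,
    j: f64,
    k: f64,
}

impl Quat {
    // Takes the field type directly and returns Self: no extra type parameter.
    fn new(r: f64, i: f64, j: f64, k: f64) -> Self {
        Quat { r, i, j, k }
    }

    // Negate the three imaginary coefficients.
    fn conjugate(&self) -> Self {
        Quat::new(self.r, -self.i, -self.j, -self.k)
    }

    // |q|^2, i.e. what the article's magnitude() returns.
    fn norm_squared(&self) -> f64 {
        self.r * self.r + self.i * self.i + self.j * self.j + self.k * self.k
    }

    // Hamilton product, same formula as the article's Mul impl.
    fn hamilton(&self, rhs: &Quat) -> Quat {
        Quat::new(
            self.r * rhs.r - self.i * rhs.i - self.j * rhs.j - self.k * rhs.k,
            self.r * rhs.i + self.i * rhs.r + self.j * rhs.k - self.k * rhs.j,
            self.r * rhs.j - self.i * rhs.k + self.j * rhs.r + self.k * rhs.i,
            self.r * rhs.k + self.i * rhs.j - self.j * rhs.i + self.k * rhs.r,
        )
    }
}

fn main() {
    let q = Quat::new(1.0, 2.0, 3.0, 4.0);
    // q * conj(q) has only a real part, and that part equals |q|^2
    println!("{:?}", q.hamilton(&q.conjugate()));
    println!("{}", q.norm_squared());
}
```

For q = 1 + 2i + 3j + 4k the product is 30 + 0i + 0j + 0k and the squared magnitude is 30, so the conjugate divided by 30 is the inverse of q.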

The first operation we will implement is the Neg trait, which will negate the Quaternion by negating all of its coefficients.

The code is as follows:

impl<T> Neg for Quaternion<T> where T:Add<Output = T> 
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T>
                           + Neg<Output = T> 
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug{
    type Output = Self;
    fn neg(self) -> Self::Output {
        Quaternion {
            r: -self.r,
            i: -self.i,
            j: -self.j,
            k: -self.k,
        }
    }
}

A few notes:

  • The impl is generic with a list of type constraints. Note that I added the Neg constraint: the generic type T has to be negatable.
  • The statement type Output = Self means that the operation produces the same kind of Quaternion as it worked on.
  • Finally, the neg() method returns Self::Output, i.e. the same type as it worked on. The code inside neg() is quite self-explanatory.

The Add trait is built along much the same lines as the Neg trait:

impl<T> Add for Quaternion<T> where T:Add<Output = T> 
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T> 
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug{
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        Quaternion {
            r: self.r + rhs.r,
            i: self.i + rhs.i,
            j: self.j + rhs.j,
            k: self.k + rhs.k,
        }
    }
}

Notice that add() takes an extra parameter, since addition operates on two operands; rhs stands for right-hand side. Apart from that, I think the code is quite clear.
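Implementing Add only gives us the binary + operator; Rust handles compound assignment separately, so q += other additionally needs the AddAssign trait from std::ops. A minimal sketch, again on an f64-only stand-in type (an assumption) that derives Copy so values can be reused after being passed to +:

```rust
use std::ops::{Add, AddAssign};

// f64-only stand-in for Quaternion<T> (assumption for brevity).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Quat {
    r: f64,
    i: f64,
    j: f64,
    k: f64,
}

impl Add for Quat {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Quat { r: self.r + rhs.r, i: self.i + rhs.i, j: self.j + rhs.j, k: self.k + rhs.k }
    }
}

// `a += b` is sugar for AddAssign::add_assign(&mut a, b).
impl AddAssign for Quat {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs; // reuse Add; works because Quat is Copy
    }
}

fn main() {
    let mut total = Quat { r: 0.0, i: 0.0, j: 0.0, k: 0.0 };
    let step = Quat { r: 1.0, i: 2.0, j: 3.0, k: 4.0 };
    total += step;
    total += step; // `step` is still usable because Quat is Copy
    println!("{:?}", total);
}
```

SubAssign, MulAssign and DivAssign follow the same pattern if -=, *= and /= are wanted too.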

The Sub trait (for subtracting two quaternions, to be clear) looks much the same as the Add trait:

impl<T> Sub for Quaternion<T> where T:Add<Output = T> 
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T> 
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug{
    type Output = Self;
    fn sub(self, rhs: Self) -> Self::Output {
        Quaternion {
            r: self.r - rhs.r,
            i: self.i - rhs.i,
            j: self.j - rhs.j,
            k: self.k - rhs.k,
        }
    }
}
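Since Neg and Add already exist, subtraction could also be expressed as a - b = a + (-b). The sketch below (f64-only stand-in type, an assumption) implements Sub that way; the article's direct component-wise form is simpler and equally valid.

```rust
use std::ops::{Add, Neg, Sub};

// f64-only stand-in for Quaternion<T> (assumption for brevity).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Quat {
    r: f64,
    i: f64,
    j: f64,
    k: f64,
}

impl Add for Quat {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Quat { r: self.r + rhs.r, i: self.i + rhs.i, j: self.j + rhs.j, k: self.k + rhs.k }
    }
}

impl Neg for Quat {
    type Output = Self;
    fn neg(self) -> Self {
        Quat { r: -self.r, i: -self.i, j: -self.j, k: -self.k }
    }
}

// Subtraction expressed through the two operators above: a - b == a + (-b).
impl Sub for Quat {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self {
        self + (-rhs)
    }
}

fn main() {
    let a = Quat { r: 1.0, i: 2.0, j: 3.0, k: 4.0 };
    let b = Quat { r: 5.0, i: 6.0, j: 7.0, k: 8.0 };
    println!("{:?}", a - b);
}
```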

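The Mul trait is also very similar to the previous two. As with complex numbers, however, multiplication is not simply component-wise: each coefficient of the product (the so-called Hamilton product) combines components of both operands, hence the somewhat more complex calculations. Unlike complex multiplication, it is also not commutative, so q1 * q2 and q2 * q1 generally differ: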

impl<T> Mul for Quaternion<T> where T:Add<Output = T> 
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T> 
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        Quaternion {
            r: self.r * rhs.r - self.i * rhs.i - self.j * rhs.j - self.k * rhs.k,
            i: self.r * rhs.i + self.i * rhs.r + self.j * rhs.k - self.k * rhs.j,
            j: self.r * rhs.j - self.i * rhs.k + self.j * rhs.r + self.k * rhs.i,
            k: self.r * rhs.k + self.i * rhs.j - self.j * rhs.i + self.k * rhs.r,
        }
    }
}
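A quick way to see the non-commutativity is to multiply the basis units themselves. The sketch below re-implements the same product formula on an f64-only stand-in type (an assumption; quat() is a hypothetical helper) and compares i * j with j * i, and q1 * q2 with q2 * q1 for the values used in main() later:

```rust
use std::ops::Mul;

// f64-only stand-in for Quaternion<T> (assumption for brevity).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Quat {
    r: f64,
    i: f64,
    j: f64,
    k: f64,
}

// Hypothetical constructor helper, just to shorten the example.
fn quat(r: f64, i: f64, j: f64, k: f64) -> Quat {
    Quat { r, i, j, k }
}

// Hamilton product: the same formula as the Mul impl above.
impl Mul for Quat {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        quat(
            self.r * rhs.r - self.i * rhs.i - self.j * rhs.j - self.k * rhs.k,
            self.r * rhs.i + self.i * rhs.r + self.j * rhs.k - self.k * rhs.j,
            self.r * rhs.j - self.i * rhs.k + self.j * rhs.r + self.k * rhs.i,
            self.r * rhs.k + self.i * rhs.j - self.j * rhs.i + self.k * rhs.r,
        )
    }
}

fn main() {
    let i = quat(0.0, 1.0, 0.0, 0.0);
    let j = quat(0.0, 0.0, 1.0, 0.0);
    println!("i * j = {:?}", i * j); // the unit k
    println!("j * i = {:?}", j * i); // minus the unit k

    let q1 = quat(1.0, 2.0, 3.0, 4.0);
    let q2 = quat(5.0, 6.0, 7.0, 8.0);
    println!("q1 * q2 = {:?}", q1 * q2);
    println!("q2 * q1 = {:?}", q2 * q1);
}
```

The two products come out as -60 + 12i + 30j + 24k and -60 + 20i + 14j + 32k: the same real part, but different imaginary parts.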

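For the Div trait, you need to know that for two quaternions $q_1$ and $q_2$ the division $q_1/q_2$ is defined as $q_1 q_2^{-1} = q_1 \left( q_2^* / |q_2|^2 \right)$, where $q_2^*$ is the conjugate of $q_2$. Because multiplication is not commutative, the order matters here: the inverse of $q_2$ is multiplied on the right.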

Hence the Div trait looks like this:

impl<T> Div for Quaternion<T> where T:Add<Output = T> 
                           + Mul<Output = T> 
                           + Div<Output = T> 
                           + Sub<Output = T>
                           + Neg<Output = T> 
                           + PartialEq 
                           + PartialOrd 
                           + Copy 
                           + Debug
                           + Default {
    type Output = Self;
    fn div(self, rhs: Self) -> Self::Output {
        // Check for a zero divisor first
        if rhs.is_null() {
            panic!("Division by zero");
        }
        // q1 / q2 = q1 * conj(q2) / |q2|^2, and magnitude() already returns |q2|^2
        let rhs_magnitude = rhs.magnitude();
        let mut q = self * rhs.conjugate();
        q.r = q.r / rhs_magnitude;
        q.i = q.i / rhs_magnitude;
        q.j = q.j / rhs_magnitude;
        q.k = q.k / rhs_magnitude;
        q
    }
}

Two notes:

  1. The formula is worked into the div() method.
  2. Note that we check for a zero divisor first and panic if there is one.
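As a sanity check of the formula, dividing and then multiplying back should recover the original quaternion, up to floating-point rounding. The sketch below (f64-only stand-in, an assumption; the 1e-12 tolerance and the approx_eq helper are arbitrary, hypothetical choices) performs the same steps as div(). Because this is right division, (q1 / q2) * q2 gives back q1, while q2 * (q1 / q2) in general does not.

```rust
use std::ops::{Div, Mul};

// f64-only stand-in for Quaternion<T> (assumption for brevity).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Quat {
    r: f64,
    i: f64,
    j: f64,
    k: f64,
}

fn quat(r: f64, i: f64, j: f64, k: f64) -> Quat {
    Quat { r, i, j, k }
}

impl Quat {
    fn conjugate(self) -> Quat {
        quat(self.r, -self.i, -self.j, -self.k)
    }
    // |q|^2, like the article's magnitude()
    fn norm_squared(self) -> f64 {
        self.r * self.r + self.i * self.i + self.j * self.j + self.k * self.k
    }
    fn is_null(self) -> bool {
        self.r == 0.0 && self.i == 0.0 && self.j == 0.0 && self.k == 0.0
    }
}

// Hamilton product, as in the article's Mul impl.
impl Mul for Quat {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        quat(
            self.r * rhs.r - self.i * rhs.i - self.j * rhs.j - self.k * rhs.k,
            self.r * rhs.i + self.i * rhs.r + self.j * rhs.k - self.k * rhs.j,
            self.r * rhs.j - self.i * rhs.k + self.j * rhs.r + self.k * rhs.i,
            self.r * rhs.k + self.i * rhs.j - self.j * rhs.i + self.k * rhs.r,
        )
    }
}

// Right division, same steps as the article's div(): q1 * conj(q2) / |q2|^2
impl Div for Quat {
    type Output = Self;
    fn div(self, rhs: Self) -> Self {
        if rhs.is_null() {
            panic!("Division by zero");
        }
        let m = rhs.norm_squared();
        let p = self * rhs.conjugate();
        quat(p.r / m, p.i / m, p.j / m, p.k / m)
    }
}

// Hypothetical helper: component-wise comparison with an absolute tolerance.
fn approx_eq(a: Quat, b: Quat, eps: f64) -> bool {
    (a.r - b.r).abs() < eps
        && (a.i - b.i).abs() < eps
        && (a.j - b.j).abs() < eps
        && (a.k - b.k).abs() < eps
}

fn main() {
    let q1 = quat(1.0, 2.0, 3.0, 4.0);
    let q2 = quat(5.0, 6.0, 7.0, 8.0);
    let q = q1 / q2;
    println!("q1 / q2        = {:?}", q);
    println!("(q1 / q2) * q2 = {:?}", q * q2); // q1, up to rounding
    println!("q2 * (q1 / q2) = {:?}", q2 * q); // generally not q1
    println!("{}", approx_eq(q * q2, q1, 1e-12));
}
```

For these values q1 / q2 works out to (70 + 8i + 0j + 16k) / 174.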

The PartialEq implementation does an element-wise comparison to see whether two quaternions are equal:

impl<T> PartialEq for Quaternion<T> where T:Add<Output = T> 
                                          + Mul<Output = T> 
                                          + Div<Output = T> 
                                          + Sub<Output = T> 
                                          + PartialEq 
                                          + PartialOrd 
                                          + Copy 
                                          + Debug
{
    fn eq(&self, other: &Self) -> bool {
        self.r == other.r && self.i == other.i && self.j == other.j && self.k == other.k
    }
}

Again, the code is quite straightforward.
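One consequence of building on PartialEq rather than Eq: with T = f64, a quaternion containing NaN is not equal to itself, because NaN != NaN for floats (which is also why f64 implements only PartialEq, not Eq). A small sketch with an f64-only stand-in (an assumption) using the same element-wise comparison:

```rust
// f64-only stand-in with the same element-wise equality as the article.
#[derive(Debug, Clone, Copy)]
struct Quat {
    r: f64,
    i: f64,
    j: f64,
    k: f64,
}

impl PartialEq for Quat {
    fn eq(&self, other: &Self) -> bool {
        self.r == other.r && self.i == other.i && self.j == other.j && self.k == other.k
    }
}

fn main() {
    let q = Quat { r: 1.0, i: f64::NAN, j: 0.0, k: 0.0 };
    // the NaN component makes q unequal even to itself
    println!("{}", q == q);
}
```

For the same reason, results of floating-point division are usually better compared with a tolerance than with ==.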

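The PartialOrd trait makes operators like < or >= possible. Here this is done by an element-wise comparison, which only gives a partial order: for many pairs of quaternions neither one is smaller than the other, and partial_cmp() then returns None. I include it here for the sake of completeness: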

impl<T> PartialOrd for Quaternion<T> where T:Add<Output = T> 
                                           + Mul<Output = T> 
                                           + Div<Output = T> 
                                           + Sub<Output = T> 
                                           + PartialEq 
                                           + PartialOrd 
                                           + Copy 
                                           + Debug
{
    // q1 <= q2 when every component of q1 is <= the matching component of q2
    fn le(&self, other: &Self) -> bool {
        self.r <= other.r && self.i <= other.i && self.j <= other.j && self.k <= other.k
    }

    // q1 >= q2 when every component of q1 is >= the matching component of q2
    fn ge(&self, other: &Self) -> bool {
        self.r >= other.r && self.i >= other.i && self.j >= other.j && self.k >= other.k
    }

    // Strict versions: <= (or >=) holds and the quaternions differ, so that
    // a <= b is exactly a < b || a == b, as PartialOrd requires
    fn lt(&self, other: &Self) -> bool {
        self.le(other) && self != other
    }

    fn gt(&self, other: &Self) -> bool {
        self.ge(other) && self != other
    }

    // Must agree with the four operators above
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        if self == other {
            Some(std::cmp::Ordering::Equal)
        } else if self.le(other) {
            Some(std::cmp::Ordering::Less)
        } else if self.ge(other) {
            Some(std::cmp::Ordering::Greater)
        } else {
            None
        }
    }
}

Also quite clear: le stands for ‘less than or equal’, lt stands for ‘less than’, and so on.
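Because the comparison is element-wise, many pairs of quaternions are simply incomparable. The sketch below (f64-only stand-in, an assumption) implements only partial_cmp() with the same element-wise idea and lets the standard library's default methods derive <, <=, > and >= from it, which keeps all of them consistent by construction:

```rust
use std::cmp::Ordering;

// f64-only stand-in for Quaternion<T> (assumption for brevity).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Quat {
    r: f64,
    i: f64,
    j: f64,
    k: f64,
}

impl PartialOrd for Quat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let a = [self.r, self.i, self.j, self.k];
        let b = [other.r, other.i, other.j, other.k];
        let all_le = a.iter().zip(b.iter()).all(|(x, y)| x <= y);
        let all_ge = a.iter().zip(b.iter()).all(|(x, y)| x >= y);
        match (all_le, all_ge) {
            (true, true) => Some(Ordering::Equal),    // every component equal
            (true, false) => Some(Ordering::Less),    // <= everywhere, < somewhere
            (false, true) => Some(Ordering::Greater), // >= everywhere, > somewhere
            (false, false) => None,                   // mixed: incomparable
        }
    }
}

fn main() {
    let a = Quat { r: 1.0, i: 2.0, j: 3.0, k: 4.0 };
    let b = Quat { r: 2.0, i: 3.0, j: 4.0, k: 5.0 };
    let c = Quat { r: 0.0, i: 5.0, j: 0.0, k: 0.0 };
    println!("{:?}", a.partial_cmp(&b)); // Some(Less)
    println!("{:?}", a.partial_cmp(&c)); // None
    println!("{} {} {}", a < c, a > c, a == c); // all false
}
```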

Now we can test these:

fn main() {
    let q1 = Quaternion::<f64>::new(1.0, 2.0, 3.0, 4.0);
    let q2 = Quaternion::<f64>::new(5.0, 6.0, 7.0, 8.0);
    let q3 = q1 + q2;
    let q4 = q1 * q2;
    let q5 = q1 / q2;
    let q6 = q1 - q2;
    println!("{:?}", q3);
    println!("{:?}", q4);
    println!("{:?}", q5);
    println!("{:?}", q6);
}

Basically, we instantiate two Quaternion<f64> values, try out the different operators on them and print the results.
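Working some of these results out by hand, using the formulas above with $q_1 = 1 + 2i + 3j + 4k$ and $q_2 = 5 + 6i + 7j + 8k$, gives something to check the printed output against:

$$
\begin{aligned}
q_1 + q_2 &= 6 + 8i + 10j + 12k \\
q_1 - q_2 &= -4 - 4i - 4j - 4k \\
q_1 q_2 &= (5 - 12 - 21 - 32) + (6 + 10 + 24 - 28)i + (7 - 16 + 15 + 24)j + (8 + 14 - 18 + 20)k \\
        &= -60 + 12i + 30j + 24k \\
|q_2|^2 &= 25 + 36 + 49 + 64 = 174 \\
q_1 q_2^* &= 70 + 8i + 0j + 16k \quad\Rightarrow\quad q_1 / q_2 = \frac{70 + 8i + 16k}{174}
\end{aligned}
$$

Because the Debug impl prints f64 coefficients with {}, whole numbers appear without a decimal part, so the first line of output should read 6 + 8i + 10j + 12k, while q6 shows the negative-number formatting issue mentioned earlier: -4 + -4i + -4j + -4k.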

Operator overloading was surprisingly easy to implement in Rust. What bugs me, however, is the long list of constraints we have to repeat every time. I will try and see if I can find a shortcut for that.

Apart from that little nuisance, I think the implementation is quite clear and elegant.
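One common way to shorten the repeated constraint lists, sketched here as a suggestion rather than the article's code: bundle the bounds into a helper trait (the name Scalar is made up) and add a blanket implementation, so every type that satisfies the bounds implements it automatically. Rust's built-in trait aliases (trait Scalar = ...;) would be shorter, but they are still an unstable feature; crates such as num-traits also offer ready-made traits like Num and Float.

```rust
use std::fmt::{Debug, Display};
use std::ops::{Add, Div, Mul, Neg, Sub};

// Hypothetical helper trait bundling all the bounds the article repeats.
trait Scalar:
    Sized
    + Copy
    + Default
    + Debug
    + Display
    + PartialEq
    + PartialOrd
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + Neg<Output = Self>
{
}

// Blanket impl: any type meeting the bounds is automatically a Scalar.
impl<T> Scalar for T where
    T: Copy
        + Default
        + Debug
        + Display
        + PartialEq
        + PartialOrd
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Neg<Output = T>
{
}

// Each impl now needs the single bound T: Scalar instead of a long list.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Quaternion<T: Scalar> {
    r: T,
    i: T,
    j: T,
    k: T,
}

impl<T: Scalar> Add for Quaternion<T> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Quaternion {
            r: self.r + rhs.r,
            i: self.i + rhs.i,
            j: self.j + rhs.j,
            k: self.k + rhs.k,
        }
    }
}

fn main() {
    let a = Quaternion { r: 1.0, i: 2.0, j: 3.0, k: 4.0 };
    let b = Quaternion { r: 5.0, i: 6.0, j: 7.0, k: 8.0 };
    println!("{:?}", a + b);
}
```

The bound still has to be written on every impl, but as one short T: Scalar, and any type like f64 or i32 that meets the requirements qualifies without extra code.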

Discover and read more posts from Iede Snoek