From 240559b94133a3d7ccc43804c42ff0e595e6d303 Mon Sep 17 00:00:00 2001 From: Ben Boeckel Date: Mon, 24 Apr 2017 19:54:30 -0400 Subject: [PATCH] iter traits: impl `iter::{Sum, Product}` This adds `Sum` trait for the `MatrixN`, `VectorN`, `Quaternion` structures and the `Product` trait for `MatrixN`, `BasisN` and `Quaternion`. It also add constraints on the `Rotation` and `SquareMatrix` to require the `Product` trait and `VectorSpace` to require `Sum`. --- src/matrix.rs | 29 +++++++++++++++++++++++++++++ src/quaternion.rs | 29 +++++++++++++++++++++++++++++ src/rotation.rs | 30 ++++++++++++++++++++++++++++++ src/structure.rs | 3 +++ src/vector.rs | 29 +++++++++++++++++++++++++++++ tests/matrix.rs | 36 ++++++++++++++++++++++++++++++++++++ tests/quaternion.rs | 10 ++++++++++ tests/vector.rs | 20 ++++++++++++++++++++ 8 files changed, 186 insertions(+) diff --git a/src/matrix.rs b/src/matrix.rs index b45dca2..d2af6f0 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -16,6 +16,7 @@ use rand::{Rand, Rng}; use num_traits::{cast, NumCast}; use std::fmt; +use std::iter; use std::mem; use std::ops::*; use std::ptr; @@ -967,6 +968,34 @@ macro_rules! impl_matrix { fn sub_assign(&mut self, other: $MatrixN) { $(self.$field -= other.$field);+ } } + impl iter::Sum for $MatrixN { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } + } + + impl<'a, S: 'a + BaseFloat> iter::Sum<&'a Self> for $MatrixN { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } + } + + impl iter::Product for $MatrixN { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Self::identity(), Mul::mul) + } + } + + impl<'a, S: 'a + BaseFloat> iter::Product<&'a Self> for $MatrixN { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Self::identity(), Mul::mul) + } + } + impl_scalar_ops!($MatrixN { $($field),+ }); impl_scalar_ops!($MatrixN { $($field),+ }); impl_scalar_ops!($MatrixN { $($field),+ }); diff --git a/src/quaternion.rs b/src/quaternion.rs index 2c6189f..92be09e 100644 --- a/src/quaternion.rs +++ b/src/quaternion.rs @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::iter; use std::mem; use std::ops::*; @@ -187,6 +188,34 @@ impl One for Quaternion { } } +impl iter::Sum for Quaternion { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } +} + +impl<'a, S: 'a + BaseFloat> iter::Sum<&'a Self> for Quaternion { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } +} + +impl iter::Product for Quaternion { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Self::one(), Mul::mul) + } +} + +impl<'a, S: 'a + BaseFloat> iter::Product<&'a Self> for Quaternion { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Self::one(), Mul::mul) + } +} + impl VectorSpace for Quaternion { type Scalar = S; } diff --git a/src/rotation.rs b/src/rotation.rs index 4a25101..6769634 100644 --- a/src/rotation.rs +++ b/src/rotation.rs @@ -14,6 +14,7 @@ // limitations under the License. use std::fmt; +use std::iter; use std::ops::*; use structure::*; @@ -33,6 +34,7 @@ pub trait Rotation: Sized + Copy + One where // FIXME: Ugly type signatures - blocked by rust-lang/rust#24092 Self: ApproxEq::Scalar>,

::Scalar: BaseFloat, + Self: iter::Product, { /// Create a rotation to a given direction with an 'up' vector. fn look_at(dir: P::Diff, up: P::Diff) -> Self; @@ -157,6 +159,20 @@ impl From> for Matrix2 { fn from(b: Basis2) -> Matrix2 { b.mat } } +impl iter::Product for Basis2 { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Basis2 { mat: Matrix2::identity() }, Mul::mul) + } +} + +impl<'a, S: 'a + BaseFloat> iter::Product<&'a Self> for Basis2 { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Basis2 { mat: Matrix2::identity() }, Mul::mul) + } +} + impl Rotation> for Basis2 { #[inline] fn look_at(dir: Vector2, up: Vector2) -> Basis2 { @@ -263,6 +279,20 @@ impl From> for Quaternion { fn from(b: Basis3) -> Quaternion { b.mat.into() } } +impl iter::Product for Basis3 { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Basis3 { mat: Matrix3::identity() }, Mul::mul) + } +} + +impl<'a, S: 'a + BaseFloat> iter::Product<&'a Self> for Basis3 { + #[inline] + fn product>(iter: I) -> Self { + iter.fold(Basis3 { mat: Matrix3::identity() }, Mul::mul) + } +} + impl Rotation> for Basis3 { #[inline] fn look_at(dir: Vector3, up: Vector3) -> Basis3 { diff --git a/src/structure.rs b/src/structure.rs index 5d7dc98..b29a1b3 100644 --- a/src/structure.rs +++ b/src/structure.rs @@ -17,6 +17,7 @@ use num_traits::{cast, Float}; use std::cmp; +use std::iter; use std::ops::*; use approx::ApproxEq; @@ -153,6 +154,7 @@ pub trait ElementWise { /// ``` pub trait VectorSpace: Copy + Clone where Self: Zero, + Self: iter::Sum, Self: Add, Self: Sub, @@ -455,6 +457,7 @@ pub trait SquareMatrix where Self::Scalar: BaseFloat, Self: One, + Self: iter::Product, Self: Matrix< // FIXME: Can be cleaned up once equality constraints in where clauses are implemented diff --git a/src/vector.rs b/src/vector.rs index f645f3e..df3f7f8 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -16,6 +16,7 @@ use rand::{Rand, Rng}; use num_traits::NumCast; use std::fmt; +use std::iter; use std::mem; use std::ops::*; @@ -163,6 +164,20 @@ macro_rules! impl_vector { } } + impl iter::Sum for $VectorN { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } + } + + impl<'a, S: 'a + BaseNum> iter::Sum<&'a Self> for $VectorN { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } + } + impl VectorSpace for $VectorN { type Scalar = S; } @@ -371,6 +386,20 @@ macro_rules! impl_vector_default { } } + impl iter::Sum for $VectorN { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } + } + + impl<'a, S: 'a + BaseNum> iter::Sum<&'a Self> for $VectorN { + #[inline] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } + } + impl VectorSpace for $VectorN { type Scalar = S; } diff --git a/tests/matrix.rs b/tests/matrix.rs index fc3cda3..43940d3 100644 --- a/tests/matrix.rs +++ b/tests/matrix.rs @@ -96,6 +96,18 @@ pub mod matrix2 { assert_eq!(A * B, &A * &B); } + #[test] + fn test_sum_matrix() { + let res: Matrix2 = [A, B, C].iter().sum(); + assert_eq!(res, A + B + C); + } + + #[test] + fn test_product_matrix() { + let res: Matrix2 = [A, B, C].iter().product(); + assert_eq!(res, A * B * C); + } + #[test] fn test_determinant() { assert_eq!(A.determinant(), -2.0f64) @@ -258,6 +270,18 @@ pub mod matrix3 { assert_eq!(A * B, &A * &B); } + #[test] + fn test_sum_matrix() { + let res: Matrix3 = [A, B, C, D].iter().sum(); + assert_eq!(res, A + B + C + D); + } + + #[test] + fn test_product_matrix() { + let res: Matrix3 = [A, B, C, D].iter().product(); + assert_eq!(res, A * B * C * D); + } + #[test] fn test_determinant() {; assert_eq!(A.determinant(), 0.0f64); @@ -615,6 +639,18 @@ pub mod matrix4 { assert_eq!(A * B, &A * &B); } + #[test] + fn test_sum_matrix() { + let res: Matrix4 = [A, B, C, D].iter().sum(); + assert_eq!(res, A + B + C + D); + } + + #[test] + fn test_product_matrix() { + let res: Matrix4 = [A, B, C, D].iter().product(); + assert_eq!(res, A * B * C * D); + } + #[test] fn test_determinant() { assert_eq!(A.determinant(), 0.0f64); diff --git a/tests/quaternion.rs b/tests/quaternion.rs index d2e9a8c..d2fbff5 100644 --- a/tests/quaternion.rs +++ b/tests/quaternion.rs @@ -52,6 +52,16 @@ mod operators { fn test_div() { impl_test_div!(2.0f32, Quaternion::from(Euler { x: Rad(1f32), y: Rad(1f32), z: Rad(1f32) })); } + + #[test] + fn test_iter_product() { + let q1 = Quaternion::from(Euler { x: Rad(2f32), y: Rad(1f32), z: Rad(1f32) }); + let q2 = Quaternion::from(Euler { x: Rad(1f32), y: Rad(2f32), z: Rad(1f32) }); + let q3 = Quaternion::from(Euler { x: Rad(1f32), y: Rad(1f32), z: Rad(2f32) }); + + let res: Quaternion = [q1, q2, q3].iter().product(); + assert_eq!(res, q1 * q2 * q3); + } } mod to_from_euler { diff --git a/tests/vector.rs b/tests/vector.rs index 307a5de..1bee749 100644 --- a/tests/vector.rs +++ b/tests/vector.rs @@ -20,6 +20,7 @@ extern crate cgmath; use cgmath::*; use std::f64; +use std::iter; #[test] fn test_constructor() { @@ -87,6 +88,14 @@ macro_rules! impl_test_rem { ) } +macro_rules! impl_test_iter_sum { + ($VectorN:ident { $($field:ident),+ }, $ty:ty, $s:expr, $v:expr) => ( + let res: $VectorN<$ty> = iter::repeat($v).take($s as usize).sum(); + assert_eq!(res, + $VectorN::new($($v.$field * $s),+)); + ) +} + #[test] fn test_add() { impl_test_add!(Vector4 { x, y, z, w }, 2.0f32, vec4(2.0f32, 4.0, 6.0, 8.0)); @@ -140,6 +149,17 @@ fn test_sum() { assert_eq!(Vector4::new(5.0f64, 6.0f64, 7.0f64, 8.0f64).sum(), 26.0f64); } +#[test] +fn test_iter_sum() { + impl_test_iter_sum!(Vector4 { x, y, z, w }, f32, 2.0f32, vec4(2.0f32, 4.0, 6.0, 8.0)); + impl_test_iter_sum!(Vector3 { x, y, z }, f32, 2.0f32, vec3(2.0f32, 4.0, 6.0)); + impl_test_iter_sum!(Vector2 { x, y }, f32, 2.0f32, vec2(2.0f32, 4.0)); + + impl_test_iter_sum!(Vector4 { x, y, z, w }, usize, 2usize, vec4(2usize, 4, 6, 8)); + impl_test_iter_sum!(Vector3 { x, y, z }, usize, 2usize, vec3(2usize, 4, 6)); + impl_test_iter_sum!(Vector2 { x, y }, usize, 2usize, vec2(2usize, 4)); +} + #[test] fn test_product() { assert_eq!(Vector2::new(1isize, 2isize).product(), 2isize);