From 0795fc6d8fe7cfc440456bbb514fa8771f9f907c Mon Sep 17 00:00:00 2001
From: danieljankowski <daniel.jankowski@rub.de>
Date: Mon, 18 Feb 2019 15:20:26 +0100
Subject: [PATCH] Added: matrix multiplication

---
 README.md       |  2 +-
 algebra.go      | 70 +++++++++++++++++++++++++++++++++++++++++++++++++
 algebra_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++---
 matrix.go       | 28 ++++++++++++++++++++
 matrix_test.go  | 34 ++++++++++++++++++++++++
 5 files changed, 195 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 29c044d..84bfd01 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ A go package for scientific matrix operations.
   - [x] Set values to matrix
   - [x] AddMatrix
   - [ ] SubMatrix
-  - [ ] MulMatrix
+  - [x] MulMatrix
   - [ ] DivMatrix
   - [ ] InvertMatrix
   - [ ] Partial gauss
diff --git a/algebra.go b/algebra.go
index cdd886a..830fccf 100644
--- a/algebra.go
+++ b/algebra.go
@@ -1,5 +1,9 @@
 package gomatrix
 
+import (
+	"math/big"
+)
+
 // AddMatrix adds two matrices
 //
 // This function adds a matrix to the matrix object. The result will be saved
@@ -24,3 +28,69 @@ func (f *F2) AddMatrix(m *F2) *F2 {
 	// return the matrix
 	return f
 }
+
+// MulMatrix multiplies matrix f with matrix m
+//
+// This functions multiplies matrix fxm. M could be a Nx1 matrix for a vector.
+// If the matrices cannot be multiplied, nil is returned and f is not
+// modified. If the multiplication was successful, the result is stored
+// in f and returned.
+//
+// @param *F2 m The matrix that is used for the multiplication
+//
+// @return *F2
+func (f *F2) MulMatrix(m *F2) *F2 {
+	// if the dimensions do not fit for a multiplication...
+	if f.N != m.M {
+		// ...retrun an error
+		return nil
+	}
+
+	// create the result matrix
+	result := NewF2(f.N, m.M)
+
+	// iterate through the rows of f
+	for i, row := range f.Rows {
+		// iterate through the columns of m
+		for j := 0; j < f.M; j++ {
+			// get the column from the second matrix
+			col := m.GetCol(j)
+
+			// multiply the vectors
+			intermediateResult := big.NewInt(0).And(row, col)
+
+			// sum up the values of the vectors
+			resultBit := addBits(intermediateResult)
+
+			// set the resulting bit to the result matrix
+			result.Rows[i].SetBit(result.Rows[i], j, resultBit)
+		}
+	}
+
+	// save the result matrix in f
+	f.Rows = result.Rows
+
+	// return the result
+	return result
+}
+
+// addBits sums up all bits of a given number
+//
+// @param *big.Int number The number to process
+//
+// @return uint
+func addBits(number *big.Int) uint {
+	// get the bit length of the number
+	bitLen := number.BitLen()
+
+	// initialize the result
+	result := uint(0)
+
+	// iterate through the bits
+	for i := 0; i < bitLen; i++ {
+		result ^= number.Bit(i)
+	}
+
+	// return the result
+	return result
+}
diff --git a/algebra_test.go b/algebra_test.go
index c32132f..31a24c2 100644
--- a/algebra_test.go
+++ b/algebra_test.go
@@ -7,7 +7,7 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-func TestMatrixAdd(t *testing.T) {
+func TestAddMatrix(t *testing.T) {
 	tests := []struct {
 		description    string
 		matrixA        *F2
@@ -21,7 +21,7 @@ func TestMatrixAdd(t *testing.T) {
 			expectedMatrix: NewF2(2, 2).Set([]*big.Int{big.NewInt(0), big.NewInt(3)}),
 		},
 		{
-			description:    "success",
+			description:    "invalid addition",
 			matrixA:        NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}),
 			matrixB:        NewF2(2, 3).Set([]*big.Int{big.NewInt(2), big.NewInt(2)}),
 			expectedMatrix: nil,
@@ -37,6 +37,65 @@ func TestMatrixAdd(t *testing.T) {
 			continue
 		}
 
-		assert.Equal(t, true, test.expectedMatrix.IsEqual(test.matrixA))
+		assert.True(t, test.matrixA.IsEqual(test.expectedMatrix))
+	}
+}
+
+func TestMulMatrix(t *testing.T) {
+	tests := []struct {
+		description    string
+		matrixA        *F2
+		matrixB        *F2
+		expectedMatrix *F2
+	}{
+		{
+			description:    "success",
+			matrixA:        NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}),
+			matrixB:        NewF2(2, 2).Set([]*big.Int{big.NewInt(1), big.NewInt(3)}),
+			expectedMatrix: NewF2(2, 2).Set([]*big.Int{big.NewInt(3), big.NewInt(1)}),
+		},
+		{
+			description:    "invalid multiplication",
+			matrixA:        NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}),
+			matrixB:        NewF2(2, 3).Set([]*big.Int{big.NewInt(1), big.NewInt(3)}),
+			expectedMatrix: nil,
+		},
+	}
+
+	for _, test := range tests {
+		test.matrixA.MulMatrix(test.matrixB)
+
+		assert.IsType(t, test.expectedMatrix, test.matrixA)
+
+		if test.expectedMatrix == nil {
+			continue
+		}
+
+		assert.True(t, test.matrixA.IsEqual(test.expectedMatrix))
+	}
+}
+
+func TestAddBits(t *testing.T) {
+	tests := []struct {
+		description    string
+		number         *big.Int
+		expectedResult uint
+	}{
+		{
+			description:    "two bits 1",
+			number:         big.NewInt(3),
+			expectedResult: 0,
+		},
+		{
+			description:    "one bit 1",
+			number:         big.NewInt(4),
+			expectedResult: 1,
+		},
+	}
+
+	for _, test := range tests {
+		result := addBits(test.number)
+
+		assert.Equal(t, test.expectedResult, result)
 	}
 }
diff --git a/matrix.go b/matrix.go
index eb89fce..eb32b70 100644
--- a/matrix.go
+++ b/matrix.go
@@ -274,3 +274,31 @@ func (f *F2) PermuteCols() *F2 {
 	// return the permutation matrix
 	return permutationMatrix
 }
+
+// GetCol returns the column at index i
+//
+// This function returns the column as big.Int after the index is verified.
+// If an invalid index is used, the function returns nil.
+//
+// @param int i The index for the column
+//
+// @return *big.Int
+func (f *F2) GetCol(i int) *big.Int {
+	// check for input parameters
+	if i < 0 || i >= f.M {
+		// return nil
+		return nil
+	}
+
+	// initialize the output big.Int
+	output := big.NewInt(0)
+
+	// iterate through the rows
+	for j, row := range f.Rows {
+		// the the corresponding bit
+		output.SetBit(output, j, row.Bit(i))
+	}
+
+	// return the result
+	return output
+}
diff --git a/matrix_test.go b/matrix_test.go
index c4b034d..85324bb 100644
--- a/matrix_test.go
+++ b/matrix_test.go
@@ -340,3 +340,37 @@ func TestPermuteCols(t *testing.T) {
 		assert.NotNil(t, permMat)
 	}
 }
+
+func TestGetCol(t *testing.T) {
+	tests := []struct {
+		description    string
+		matrixA        *F2
+		colIndex       int
+		expectedResult *big.Int
+	}{
+		{
+			description:    "first column of a 2x2 matrix",
+			matrixA:        NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}),
+			colIndex:       0,
+			expectedResult: big.NewInt(2),
+		},
+		{
+			description:    "second column of a 2x2 matrix",
+			matrixA:        NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}),
+			colIndex:       1,
+			expectedResult: big.NewInt(1),
+		},
+		{
+			description:    "invalid parameter",
+			matrixA:        NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}),
+			colIndex:       2,
+			expectedResult: nil,
+		},
+	}
+
+	for _, test := range tests {
+		result := test.matrixA.GetCol(test.colIndex)
+
+		assert.Equal(t, test.expectedResult, result)
+	}
+}
-- 
GitLab