From 0795fc6d8fe7cfc440456bbb514fa8771f9f907c Mon Sep 17 00:00:00 2001 From: danieljankowski <daniel.jankowski@rub.de> Date: Mon, 18 Feb 2019 15:20:26 +0100 Subject: [PATCH] Added: matrix multiplication --- README.md | 2 +- algebra.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++++ algebra_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++--- matrix.go | 28 ++++++++++++++++++++ matrix_test.go | 34 ++++++++++++++++++++++++ 5 files changed, 195 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 29c044d..84bfd01 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ A go package for scientific matrix operations. - [x] Set values to matrix - [x] AddMatrix - [ ] SubMatrix - - [ ] MulMatrix + - [x] MulMatrix - [ ] DivMatrix - [ ] InvertMatrix - [ ] Partial gauss diff --git a/algebra.go b/algebra.go index cdd886a..830fccf 100644 --- a/algebra.go +++ b/algebra.go @@ -1,5 +1,9 @@ package gomatrix +import ( + "math/big" +) + // AddMatrix adds two matrices // // This function adds a matrix to the matrix object. The result will be saved @@ -24,3 +28,69 @@ func (f *F2) AddMatrix(m *F2) *F2 { // return the matrix return f } + +// MulMatrix multiplies matrix f with matrix m +// +// This functions multiplies matrix fxm. M could be a Nx1 matrix for a vector. +// If the matrices cannot be multiplied, nil is returned and f is not +// modified. If the multiplication was successful, the result is stored +// in f and returned. +// +// @param *F2 m The matrix that is used for the multiplication +// +// @return *F2 +func (f *F2) MulMatrix(m *F2) *F2 { + // if the dimensions do not fit for a multiplication... + if f.N != m.M { + // ...retrun an error + return nil + } + + // create the result matrix + result := NewF2(f.N, m.M) + + // iterate through the rows of f + for i, row := range f.Rows { + // iterate through the columns of m + for j := 0; j < f.M; j++ { + // get the column from the second matrix + col := m.GetCol(j) + + // multiply the vectors + intermediateResult := big.NewInt(0).And(row, col) + + // sum up the values of the vectors + resultBit := addBits(intermediateResult) + + // set the resulting bit to the result matrix + result.Rows[i].SetBit(result.Rows[i], j, resultBit) + } + } + + // save the result matrix in f + f.Rows = result.Rows + + // return the result + return result +} + +// addBits sums up all bits of a given number +// +// @param *big.Int number The number to process +// +// @return uint +func addBits(number *big.Int) uint { + // get the bit length of the number + bitLen := number.BitLen() + + // initialize the result + result := uint(0) + + // iterate through the bits + for i := 0; i < bitLen; i++ { + result ^= number.Bit(i) + } + + // return the result + return result +} diff --git a/algebra_test.go b/algebra_test.go index c32132f..31a24c2 100644 --- a/algebra_test.go +++ b/algebra_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMatrixAdd(t *testing.T) { +func TestAddMatrix(t *testing.T) { tests := []struct { description string matrixA *F2 @@ -21,7 +21,7 @@ func TestMatrixAdd(t *testing.T) { expectedMatrix: NewF2(2, 2).Set([]*big.Int{big.NewInt(0), big.NewInt(3)}), }, { - description: "success", + description: "invalid addition", matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}), matrixB: NewF2(2, 3).Set([]*big.Int{big.NewInt(2), big.NewInt(2)}), expectedMatrix: nil, @@ -37,6 +37,65 @@ func TestMatrixAdd(t *testing.T) { continue } - assert.Equal(t, true, test.expectedMatrix.IsEqual(test.matrixA)) + assert.True(t, test.matrixA.IsEqual(test.expectedMatrix)) + } +} + +func TestMulMatrix(t *testing.T) { + tests := []struct { + description string + matrixA *F2 + matrixB *F2 + expectedMatrix *F2 + }{ + { + description: "success", + matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}), + matrixB: NewF2(2, 2).Set([]*big.Int{big.NewInt(1), big.NewInt(3)}), + expectedMatrix: NewF2(2, 2).Set([]*big.Int{big.NewInt(3), big.NewInt(1)}), + }, + { + description: "invalid multiplication", + matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}), + matrixB: NewF2(2, 3).Set([]*big.Int{big.NewInt(1), big.NewInt(3)}), + expectedMatrix: nil, + }, + } + + for _, test := range tests { + test.matrixA.MulMatrix(test.matrixB) + + assert.IsType(t, test.expectedMatrix, test.matrixA) + + if test.expectedMatrix == nil { + continue + } + + assert.True(t, test.matrixA.IsEqual(test.expectedMatrix)) + } +} + +func TestAddBits(t *testing.T) { + tests := []struct { + description string + number *big.Int + expectedResult uint + }{ + { + description: "two bits 1", + number: big.NewInt(3), + expectedResult: 0, + }, + { + description: "one bit 1", + number: big.NewInt(4), + expectedResult: 1, + }, + } + + for _, test := range tests { + result := addBits(test.number) + + assert.Equal(t, test.expectedResult, result) } } diff --git a/matrix.go b/matrix.go index eb89fce..eb32b70 100644 --- a/matrix.go +++ b/matrix.go @@ -274,3 +274,31 @@ func (f *F2) PermuteCols() *F2 { // return the permutation matrix return permutationMatrix } + +// GetCol returns the column at index i +// +// This function returns the column as big.Int after the index is verified. +// If an invalid index is used, the function returns nil. +// +// @param int i The index for the column +// +// @return *big.Int +func (f *F2) GetCol(i int) *big.Int { + // check for input parameters + if i < 0 || i >= f.M { + // return nil + return nil + } + + // initialize the output big.Int + output := big.NewInt(0) + + // iterate through the rows + for j, row := range f.Rows { + // the the corresponding bit + output.SetBit(output, j, row.Bit(i)) + } + + // return the result + return output +} diff --git a/matrix_test.go b/matrix_test.go index c4b034d..85324bb 100644 --- a/matrix_test.go +++ b/matrix_test.go @@ -340,3 +340,37 @@ func TestPermuteCols(t *testing.T) { assert.NotNil(t, permMat) } } + +func TestGetCol(t *testing.T) { + tests := []struct { + description string + matrixA *F2 + colIndex int + expectedResult *big.Int + }{ + { + description: "first column of a 2x2 matrix", + matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}), + colIndex: 0, + expectedResult: big.NewInt(2), + }, + { + description: "second column of a 2x2 matrix", + matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}), + colIndex: 1, + expectedResult: big.NewInt(1), + }, + { + description: "invalid parameter", + matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(2), big.NewInt(1)}), + colIndex: 2, + expectedResult: nil, + }, + } + + for _, test := range tests { + result := test.matrixA.GetCol(test.colIndex) + + assert.Equal(t, test.expectedResult, result) + } +} -- GitLab