diff --git a/matrix.go b/matrix.go index a76ba751017518ffaf42ee0ce89ef4b0a943d19d..a4241aa584811a6a89f54c435958199d7d04f0d3 100644 --- a/matrix.go +++ b/matrix.go @@ -198,7 +198,10 @@ func (f *F2) SetToIdentity() { // // @param int i The index of the first row to swap // @param int j The index of the second row to swap +// +// @return error func (f *F2) SwapRows(i, j int) error { + // check for input parameters if i >= f.N || j >= f.N || i < 0 || j < 0 { return fmt.Errorf("Index does not exist") } @@ -208,3 +211,30 @@ func (f *F2) SwapRows(i, j int) error { // return success return nil } + +// SwapCols swaps the columns at index i with the row at index j +// +// @param int i The index of the first columns to swap +// @param int j The index of the second columns to swap +// +// @return error +func (f *F2) SwapCols(i, j int) error { + // check for input parameters + if i >= f.M || j >= f.M || i < 0 || j < 0 { + return fmt.Errorf("Index does not exist") + } + + // iterate through the rows + for _, row := range f.Rows { + // get the bit with the given index + bitI := row.Bit(i) + bitJ := row.Bit(j) + + // set the swapped bits + row.SetBit(row, i, bitJ) + row.SetBit(row, j, bitI) + } + + // return success + return nil +} diff --git a/matrix_test.go b/matrix_test.go index 22b18816500eb968927c051acadeb1d583d35472..c0cfe9eda6e7952222904056ad5baaacd412c09a 100644 --- a/matrix_test.go +++ b/matrix_test.go @@ -240,7 +240,7 @@ func TestSwapRows(t *testing.T) { matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(3), big.NewInt(1)}), i: 2, j: 0, - expectedMatrix: NewF2(2, 2).Set([]*big.Int{big.NewInt(1), big.NewInt(3)}), + expectedMatrix: NewF2(2, 2), expectedError: true, }, } @@ -257,3 +257,43 @@ func TestSwapRows(t *testing.T) { assert.Equal(t, true, test.matrixA.IsEqual(test.expectedMatrix)) } } + +func TestSwapCols(t *testing.T) { + tests := []struct { + description string + matrixA *F2 + i int + j int + expectedMatrix *F2 + expectedError bool + }{ + { + description: "3x3 matrix", + matrixA: NewF2(3, 3).Set([]*big.Int{big.NewInt(3), big.NewInt(1), big.NewInt(5)}), + i: 2, + j: 0, + expectedMatrix: NewF2(3, 3).Set([]*big.Int{big.NewInt(6), big.NewInt(4), big.NewInt(5)}), + expectedError: false, + }, + { + description: "invalid index", + matrixA: NewF2(2, 2).Set([]*big.Int{big.NewInt(3), big.NewInt(1)}), + i: 2, + j: 0, + expectedMatrix: NewF2(2, 2), + expectedError: true, + }, + } + + for _, test := range tests { + err := test.matrixA.SwapCols(test.i, test.j) + + assert.Equal(t, test.expectedError, err != nil) + + if err != nil { + continue + } + + assert.Equal(t, true, test.matrixA.IsEqual(test.expectedMatrix)) + } +}