diff --git a/gaussian.go b/gaussian.go index c74798d4aea2bbdccb9b416f1f16853b044f5718..caf9c434f867db174ea028c7143d48cd5a455b98 100644 --- a/gaussian.go +++ b/gaussian.go @@ -35,6 +35,8 @@ func (f *F2) GaussianElimination() { f.Rows[rr].Xor(f.Rows[rr], f.Rows[pivotBit]) } } + + //TODO: detect linear dependency } // do the same thing backwards to get the identity matrix @@ -126,6 +128,88 @@ func (f *F2) partialDiagonalize(startRow, startCol, stopRow, stopCol int) { } } +// PartialGaussianWithLinearChecking performs a partial gaussian elimination +// +// This function performs a gaussian elimination on the matrix and calls the +// check callback after each iteration in order to verify that linear +// dependencies in the code could be resolved easily. +func (f *F2) PartialGaussianWithLinearChecking( + startRow int, + startCol int, + stopRow int, + stopCol int, + linearCheck func(*F2, int, int, int, int, int) error, +) error { + // iterate through all possible pivot bits + for pivotBit := startCol; pivotBit <= stopCol; pivotBit++ { + // intialize the pivotbit indicator + foundPivotBit := false + + // iterate through the rows + for rowCounter := startRow + pivotBit - startCol; rowCounter <= stopRow; rowCounter++ { + // if the pivotbit of this row is 0... + if f.Rows[rowCounter].Bit(pivotBit) == uint(0) { + // ...check the next row + continue + } + + // if the row with a valid pivot bit is not the first row... + if pivotBit-startCol != rowCounter { + // ...swap it with first one + f.SwapRows(pivotBit-startCol, rowCounter) + } + + // iterate through all other rows except the first one + for rr := startRow + pivotBit - startCol + 1; rr <= stopRow; rr++ { + if f.Rows[rr].Bit(pivotBit) == uint(0) { + continue + } + + // subtract the 1 from all other rows with the pivotBit + f.Rows[rr].Xor( + f.Rows[rr], + f.Rows[startRow+pivotBit-startCol], + ) + } + + // indicate the pivotbit is found + foundPivotBit = true + + // break out of the loop + break + } + + // if a pivot bit was found... + if foundPivotBit { + // ...skip to the next row + continue + } + + // detect linear dependencies and try to resolve them + err := linearCheck( + f, + startRow, + startCol, + stopRow, + stopCol, + pivotBit, + ) + + // check the error + if err != nil { + return err + } + + // process the same row again + pivotBit-- + } + + // do the same thing backwards to get the identity matrix + f.partialDiagonalize(startRow, startCol, stopRow, stopCol) + + return nil +} + // CheckGaussian checks if the given range in the matrix is the identity matrix // // @param int startRow The row where the check starts diff --git a/gaussian_test.go b/gaussian_test.go index 6a83b3c0e6b1183e05f6a6192fb353aac8136cc6..3aca9bc796ee343555460e2892eafcdf579af8f3 100644 --- a/gaussian_test.go +++ b/gaussian_test.go @@ -2,6 +2,7 @@ package gomatrix import ( + "fmt" "math/big" "testing" @@ -103,12 +104,166 @@ func TestPartialGaussianElimination(t *testing.T) { test.stopCol, ) - test.matrixA.PrettyPrint() - assert.True(t, test.matrixA.IsEqual(test.expectedMatrix)) } } +func TestPartialGaussianWithLinearChecking(t *testing.T) { + tests := []struct { + description string + matrix *F2 + startRow int + startCol int + stopRow int + stopCol int + linearCheck func(*F2, int, int, int, int, int) error + expectedResult *F2 + expectedError bool + }{ + { + description: "4x4 with one swap", + matrix: NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(12), + big.NewInt(14), + }), + startRow: 0, + startCol: 1, + stopRow: 2, + stopCol: 3, + linearCheck: func(f *F2, startRow, startCol, stopRow, stopCol, pivotBit int) error { + // create a bitmask for the row check + bitmask := big.NewInt(0).SetBit(big.NewInt(0), stopCol-startCol+1, 1) + bitmask = bitmask.Sub(bitmask, big.NewInt(1)) + bitmask = bitmask.Lsh(bitmask, uint(startCol)) + + foundValidRow := false + + // iterate through the rows + for index, row := range f.Rows { + // skip all rows, that are processed by the gaussian elimination + if index >= startRow && index <= stopRow { + continue + } + + // get the bits to check + bitsToCheck := big.NewInt(0).And( + bitmask, + row, + ) + + // if the bits are 0... + if bitsToCheck.Cmp(big.NewInt(0)) == 0 { + // ...skip the row + continue + } + + // swap the rows + f.SwapRows(pivotBit-1, index) + + foundValidRow = true + + // exit the loop + break + } + + if !foundValidRow { + return fmt.Errorf("cannot resolv linear dependency") + } + + for i := startCol; i < pivotBit; i++ { + if f.Rows[pivotBit-startCol].Bit(i) == uint(0) { + continue + } + + fmt.Printf("%d xor %d\n", pivotBit-startCol, startRow+i-startCol) + + f.Rows[pivotBit-startCol].Xor( + f.Rows[pivotBit-startCol], + f.Rows[startRow+i-startCol], + ) + } + + return nil + }, + expectedResult: NewF2(4, 4).Set([]*big.Int{ + big.NewInt(3), + big.NewInt(4), + big.NewInt(9), + big.NewInt(1), + }), + expectedError: false, + }, + { + description: "4x4 with error", + matrix: NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(12), + big.NewInt(14), + }), + startRow: 0, + startCol: 1, + stopRow: 2, + stopCol: 3, + linearCheck: func(f *F2, startRow, startCol, stopRow, stopCol, pivotBit int) error { + return fmt.Errorf("testfoo") + }, + expectedResult: NewF2(4, 4).Set([]*big.Int{ + big.NewInt(3), + big.NewInt(4), + big.NewInt(9), + big.NewInt(1), + }), + expectedError: true, + }, + { + description: "4x4 with error", + matrix: NewF2(4, 4).Set([]*big.Int{ + big.NewInt(13), + big.NewInt(10), + big.NewInt(12), + big.NewInt(14), + }), + startRow: 0, + startCol: 1, + stopRow: 2, + stopCol: 3, + linearCheck: func(f *F2, startRow, startCol, stopRow, stopCol, pivotBit int) error { + return fmt.Errorf("testfoo") + }, + expectedResult: NewF2(4, 4).Set([]*big.Int{ + big.NewInt(3), + big.NewInt(4), + big.NewInt(9), + big.NewInt(1), + }), + expectedError: true, + }, + } + + for _, test := range tests { + err := test.matrix.PartialGaussianWithLinearChecking( + test.startRow, + test.startCol, + test.stopRow, + test.stopCol, + test.linearCheck, + ) + + test.matrix.PrettyPrint() + + assert.Equalf(t, test.expectedError, err != nil, test.description) + + if err != nil { + continue + } + + assert.Truef(t, test.expectedResult.IsEqual(test.matrix), test.description) + } +} + func TestCheckGaussian(t *testing.T) { tests := []struct { description string diff --git a/resolver/lineardependency.go b/resolver/lineardependency.go new file mode 100644 index 0000000000000000000000000000000000000000..96e6397431639f473153930830e19804180b3e38 --- /dev/null +++ b/resolver/lineardependency.go @@ -0,0 +1,76 @@ +package resolver + +import ( + "fmt" + "math/big" + + "git.noc.ruhr-uni-bochum.de/danieljankowski/gomatrix" +) + +// LinearDependenciesInGauss tries to resolve linear dependencies in the gaussian +// elimination. +// +// This function is used in order to try to resolve linear dependencies while +// using the function PartialGaussianWithLinearChecking as linearCheck-function. +func LinearDependenciesInGauss( + f *gomatrix.F2, + startRow int, + startCol int, + stopRow int, + stopCol int, + pivotBit int, +) error { + // create a bitmask for the row check + bitmask := big.NewInt(0).SetBit(big.NewInt(0), stopCol-startCol+1, 1) + bitmask = bitmask.Sub(bitmask, big.NewInt(1)) + bitmask = bitmask.Lsh(bitmask, uint(startCol)) + + foundValidRow := false + + // iterate through the rows + for index, row := range f.Rows { + // skip all rows, that are processed by the gaussian elimination + if index >= startRow && index <= stopRow { + continue + } + + // get the bits to check + bitsToCheck := big.NewInt(0).And( + bitmask, + row, + ) + + // if the bits are 0... + if bitsToCheck.Cmp(big.NewInt(0)) == 0 { + // ...skip the row + continue + } + + // swap the rows + f.SwapRows(pivotBit-1, index) + + foundValidRow = true + + // exit the loop + break + } + + if !foundValidRow { + return fmt.Errorf("cannot resolve linear dependency") + } + + for i := startCol; i < pivotBit; i++ { + if f.Rows[pivotBit-startCol].Bit(i) == uint(0) { + continue + } + + fmt.Printf("%d xor %d\n", pivotBit-startCol, startRow+i-startCol) + + f.Rows[pivotBit-startCol].Xor( + f.Rows[pivotBit-startCol], + f.Rows[startRow+i-startCol], + ) + } + + return nil +} diff --git a/resolver/lineardependency_test.go b/resolver/lineardependency_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5d03ebc659987dcd098c4606e4bbc0d6ccb9b248 --- /dev/null +++ b/resolver/lineardependency_test.go @@ -0,0 +1,107 @@ +package resolver + +import ( + "math/big" + "testing" + + "git.noc.ruhr-uni-bochum.de/danieljankowski/gomatrix" + + "github.com/stretchr/testify/assert" +) + +func TestLinearDependenciesInGauss(t *testing.T) { + tests := []struct { + description string + matrix *gomatrix.F2 + startRow int + startCol int + stopRow int + stopCol int + pivotBit int + expectedError bool + expectedResult *gomatrix.F2 + }{ + { + description: "simple swap and postprocessing of the row", + matrix: gomatrix.NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(1), + big.NewInt(14), + }), + startRow: 0, + startCol: 1, + stopRow: 2, + stopCol: 3, + pivotBit: 3, + expectedError: false, + expectedResult: gomatrix.NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(9), + big.NewInt(1), + }), + }, + { + description: "no row to swap with", + matrix: gomatrix.NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(1), + big.NewInt(1), + }), + startRow: 0, + startCol: 1, + stopRow: 2, + stopCol: 3, + pivotBit: 3, + expectedError: true, + expectedResult: gomatrix.NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(9), + big.NewInt(1), + }), + }, + { + description: "simple swap and postprocessing of the row + continue", + matrix: gomatrix.NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(1), + big.NewInt(10), + }), + startRow: 0, + startCol: 1, + stopRow: 2, + stopCol: 3, + pivotBit: 3, + expectedError: false, + expectedResult: gomatrix.NewF2(4, 4).Set([]*big.Int{ + big.NewInt(10), + big.NewInt(13), + big.NewInt(0), + big.NewInt(1), + }), + }, + } + + for _, test := range tests { + err := LinearDependenciesInGauss( + test.matrix, + test.startRow, + test.startCol, + test.stopRow, + test.stopCol, + test.pivotBit, + ) + + assert.Equalf(t, test.expectedError, err != nil, test.description) + + if err != nil { + continue + } + + assert.Truef(t, test.expectedResult.IsEqual(test.matrix), test.description) + } +}