1package ndr
2
3import (
4	"bytes"
5	"encoding/hex"
6	"testing"
7
8	"github.com/stretchr/testify/assert"
9)
10
11func TestReadCommonHeader(t *testing.T) {
12	var tests = []struct {
13		EncodedHex string
14		ExpectFail bool
15	}{
16		{"01100800cccccccc", false}, // Little Endian
17		{"01000008cccccccc", false}, // Big Endian have to change the bytes for the header size? This test vector was artificially created. Need proper test vector
18		//{"01100800cccccccc1802000000000000", false},
19		//{"01100800cccccccc0002000000000000", false},
20		//{"01100800cccccccc0001000000000000", false},
21		//{"01100800cccccccce000000000000000", false},
22		//{"01100800ccccccccf000000000000000", false},
23		//{"01100800cccccccc7801000000000000", false},
24		//{"01100800cccccccc4801000000000000", false},
25		//{"01100800ccccccccd001000000000000", false},
26		{"02100800cccccccc", true}, // Incorrect version
27		{"02100900cccccccc", true}, // Incorrect length
28
29	}
30
31	for i, test := range tests {
32		b, _ := hex.DecodeString(test.EncodedHex)
33		dec := NewDecoder(bytes.NewReader(b))
34		err := dec.readCommonHeader()
35		if err != nil && !test.ExpectFail {
36			t.Errorf("error reading common header of test %d: %v", i, err)
37		}
38		if err == nil && test.ExpectFail {
39			t.Errorf("expected failure on reading common header of test %d: %v", i, err)
40		}
41	}
42}
43
44func TestReadPrivateHeader(t *testing.T) {
45	var tests = []struct {
46		EncodedHex string
47		ExpectFail bool
48		Length     int
49	}{
50		{"01100800cccccccc1802000000000000", false, 536},
51		{"01100800cccccccc0002000000000000", false, 512},
52		{"01100800cccccccc0001000000000000", false, 256},
53		{"01100800ccccccccFF00000000000000", true, 255}, // Length not multiple of 8
54		{"01100800cccccccc00010000000000", true, 256},   // Too short
55
56	}
57
58	for i, test := range tests {
59		b, _ := hex.DecodeString(test.EncodedHex)
60		dec := NewDecoder(bytes.NewReader(b))
61		err := dec.readCommonHeader()
62		if err != nil {
63			t.Errorf("error reading common header of test %d: %v", i, err)
64		}
65		err = dec.readPrivateHeader()
66		if err != nil && !test.ExpectFail {
67			t.Errorf("error reading private header of test %d: %v", i, err)
68		}
69		if err == nil && test.ExpectFail {
70			t.Errorf("expected failure on reading private header of test %d: %v", i, err)
71		}
72		if dec.ph.ObjectBufferLength != uint32(test.Length) {
73			t.Errorf("Objectbuffer length expected %d actual %d", test.Length, dec.ph.ObjectBufferLength)
74		}
75	}
76}
77
// SimpleTest is a minimal decode target made of two consecutive 32-bit
// unsigned fields. TestBasicDecode and TestBasicDecodeOverRun decode into it.
type SimpleTest struct {
	A, B uint32
}
82
83func TestBasicDecode(t *testing.T) {
84	hexStr := "01100800cccccccca00400000000000000000200d186660f656ac601"
85	b, _ := hex.DecodeString(hexStr)
86	ft := new(SimpleTest)
87	dec := NewDecoder(bytes.NewReader(b))
88	err := dec.Decode(ft)
89	if err != nil {
90		t.Fatalf("error decoding: %v", err)
91	}
92	assert.Equal(t, uint32(258377425), ft.A, "Value of field A not as expected")
93	assert.Equal(t, uint32(29780581), ft.B, "Value of field B not as expected %d")
94}
95
96func TestBasicDecodeOverRun(t *testing.T) {
97	hexStr := "01100800cccccccca00400000000000000000200d186660f"
98	b, _ := hex.DecodeString(hexStr)
99	ft := new(SimpleTest)
100	dec := NewDecoder(bytes.NewReader(b))
101	err := dec.Decode(ft)
102	if err == nil {
103		t.Errorf("Expected error for trying to read more than the bytes we have")
104	}
105}
106
// testEmbeddingPointer is the outermost fixture struct decoded by
// Test_EmbeddedPointers. It nests testEmbeddedPointer behind an
// `ndr:"pointer"` tag, followed by a plain uint32.
type testEmbeddingPointer struct {
	A testEmbeddedPointer `ndr:"pointer"`
	B uint32              // Test_EmbeddedPointers expects 1
}
111
// testEmbeddedPointer is the middle level of the nested fixture used by
// Test_EmbeddedPointers: a pointer-tagged struct, a pointer-tagged uint32 and
// a plain uint32.
type testEmbeddedPointer struct {
	C testEmbeddedPointer2 `ndr:"pointer"`
	D uint32               `ndr:"pointer"` // Test_EmbeddedPointers expects 2
	E uint32               // Test_EmbeddedPointers expects 3
}
117
// testEmbeddedPointer2 is the innermost level of the nested fixture used by
// Test_EmbeddedPointers: one pointer-tagged uint32 and one plain uint32.
type testEmbeddedPointer2 struct {
	F uint32 `ndr:"pointer"` // Test_EmbeddedPointers expects 4
	G uint32 // Test_EmbeddedPointers expects 5
}
122
123func Test_EmbeddedPointers(t *testing.T) {
124	hexStr := TestHeader + "00040002" + "01000000" + "00040002" + "00040002" + "03000000" + "00040002" + "05000000" + "04000000" + "02000000"
125	b, _ := hex.DecodeString(hexStr)
126	ft := new(testEmbeddingPointer)
127	dec := NewDecoder(bytes.NewReader(b))
128	err := dec.Decode(ft)
129	if err != nil {
130		t.Fatalf("error decoding: %v", err)
131	}
132	assert.Equal(t, uint32(1), ft.B)
133	assert.Equal(t, uint32(2), ft.A.D)
134	assert.Equal(t, uint32(3), ft.A.E)
135	assert.Equal(t, uint32(4), ft.A.C.F)
136	assert.Equal(t, uint32(5), ft.A.C.G)
137}
138