summaryrefslogtreecommitdiffstats
path: root/Godeps/_workspace/src/code.google.com/p/graphics-go/graphics/detect/opencv_parser.go
blob: 51ded1a1c636a5011aeb7ff3ab988e161cf424ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright 2011 The Graphics-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package detect

import (
	"bytes"
	"encoding/xml"
	"errors"
	"fmt"
	"image"
	"io"
	"io/ioutil"
	"strconv"
	"strings"
)

// xmlFeature is the decoded form of one tree node of an OpenCV Haar cascade,
// as seen after ParseOpenCV has rewritten the anonymous <_> elements to <grp>
// (hence the "grp" path segments in the tags).
//
// Rects holds the raw text of each rectangle; buildFeature parses each entry
// as "x y w h weight". buildCascade rejects a node whose Tilted value is
// non-zero and copies Threshold, Left and Right into the resulting Classifier.
type xmlFeature struct {
	Rects     []string `xml:"grp>feature>rects>grp"`
	Tilted    int      `xml:"grp>feature>tilted"`
	Threshold float64  `xml:"grp>threshold"`
	Left      float64  `xml:"grp>left_val"`
	Right     float64  `xml:"grp>right_val"`
}

// xmlStages is the decoded form of one stage of the cascade (one <grp> under
// <stages> after ParseOpenCV's <_> → <grp> rewrite).
//
// buildCascade turns each entry of Trees into a Classifier and uses
// Stage_threshold as the CascadeStage's Threshold. Parent and Next are
// decoded but not read by buildCascade.
type xmlStages struct {
	Trees           []xmlFeature `xml:"trees>grp"`
	Stage_threshold float64      `xml:"stage_threshold"`
	Parent          int          `xml:"parent"`
	Next            int          `xml:"next"`
}

// opencv_storage is the root of an OpenCV storage XML document.
//
// The ",any" tag lets Any capture a child element of the root regardless of
// its element name. buildCascade uses that element's name (XMLName.Local) as
// the cascade name, requires its type_id attribute to equal
// "opencv-haar-classifier", and reads Size as the width and height stored in
// Cascade.Size.
//
// NOTE(review): Any is a single struct, not a slice; presumably the expected
// files hold exactly one classifier element under the root — confirm.
type opencv_storage struct {
	Any struct {
		XMLName xml.Name
		Type    string      `xml:"type_id,attr"`
		Size    string      `xml:"size"`
		Stages  []xmlStages `xml:"stages>grp"`
	} `xml:",any"`
}

// buildFeature parses one OpenCV rectangle description, "x y w h weight",
// into a Feature whose Rect spans (x, y)-(x+w, y+h) and whose Weight is the
// fifth value.
//
// The text is split with strings.Fields, so any run of whitespace, including
// the newlines and indentation that pretty-printed XML leaves around the
// numbers, separates values. fmt.Sscanf would instead fail with
// "unexpected newline" on such input. Values after the fifth are ignored.
//
// On error the returned Feature is the zero value.
func buildFeature(r string) (f Feature, err error) {
	fields := strings.Fields(r)
	if len(fields) < 5 {
		return Feature{}, fmt.Errorf("rect %q: want 5 values (x y w h weight), got %d", r, len(fields))
	}

	var v [4]int // x, y, w, h in that order
	for i := range v {
		if v[i], err = strconv.Atoi(fields[i]); err != nil {
			return Feature{}, err
		}
	}
	weight, err := strconv.ParseFloat(fields[4], 64)
	if err != nil {
		return Feature{}, err
	}

	x, y, w, h := v[0], v[1], v[2], v[3]
	return Feature{Rect: image.Rect(x, y, x+w, y+h), Weight: weight}, nil
}

// buildCascade converts a decoded OpenCV storage document into a Cascade.
//
// The captured child element must carry type_id="opencv-haar-classifier";
// its element name is returned as name. Its <size> text must contain at
// least two whitespace-separated integers, used as the width and height of
// Cascade.Size. Each stage becomes a CascadeStage, each tree of a stage
// becomes a Classifier, and each rectangle of a tree is parsed by
// buildFeature. Tilted features are not supported and produce an error.
//
// On any error the returned Cascade is nil and name is empty.
func buildCascade(s *opencv_storage) (c *Cascade, name string, err error) {
	if s.Any.Type != "opencv-haar-classifier" {
		return nil, "", fmt.Errorf("got %s want opencv-haar-classifier", s.Any.Type)
	}

	// strings.Fields tolerates newlines or repeated spaces around the numbers.
	// The length check turns a short <size> into an error instead of an
	// index-out-of-range panic.
	sizes := strings.Fields(s.Any.Size)
	if len(sizes) < 2 {
		return nil, "", fmt.Errorf("cascade size %q: want width and height", s.Any.Size)
	}
	w, err := strconv.Atoi(sizes[0])
	if err != nil {
		return nil, "", err
	}
	h, err := strconv.Atoi(sizes[1])
	if err != nil {
		return nil, "", err
	}

	c = &Cascade{
		Size:  image.Pt(w, h),
		Stage: make([]CascadeStage, 0, len(s.Any.Stages)),
	}

	for _, stage := range s.Any.Stages {
		cs := CascadeStage{
			Classifier: make([]Classifier, 0, len(stage.Trees)),
			Threshold:  stage.Stage_threshold,
		}
		for _, tree := range stage.Trees {
			if tree.Tilted != 0 {
				// Return nil rather than the partially built cascade so
				// callers never see a half-populated result alongside an error.
				return nil, "", errors.New("Cascade does not support tilted features")
			}

			cls := Classifier{
				Feature:   make([]Feature, 0, len(tree.Rects)),
				Threshold: tree.Threshold,
				Left:      tree.Left,
				Right:     tree.Right,
			}
			for _, rect := range tree.Rects {
				f, err := buildFeature(rect)
				if err != nil {
					return nil, "", err
				}
				cls.Feature = append(cls.Feature, f)
			}

			cs.Classifier = append(cs.Classifier, cls)
		}
		c.Stage = append(c.Stage, cs)
	}

	return c, s.Any.XMLName.Local, nil
}

// ParseOpenCV produces a detection Cascade from an OpenCV XML file.
//
// The whole of r is read into memory. Every literal "<_>" and "</_>" tag is
// renamed to "<grp>" / "</grp>" so that the "grp" struct tags used by this
// file can match those elements. The document is then decoded and handed to
// buildCascade. Only the exact tag text is renamed; an anonymous element
// written with attributes or extra whitespace would be left alone.
func ParseOpenCV(r io.Reader) (cascade *Cascade, name string, err error) {
	// BUG(crawshaw): tag-based parsing doesn't seem to work with <_>
	raw, err := ioutil.ReadAll(r)
	if err != nil {
		return nil, "", err
	}

	renames := [...][2]string{
		{"<_>", "<grp>"},
		{"</_>", "</grp>"},
	}
	for _, pair := range renames {
		raw = bytes.Replace(raw, []byte(pair[0]), []byte(pair[1]), -1)
	}

	var doc opencv_storage
	if err = xml.Unmarshal(raw, &doc); err != nil {
		return nil, "", err
	}
	return buildCascade(&doc)
}