CGO函数调用与数据转换

CGO是Go程序调用C库的一套机制,可以使Go语言能够站在C/C++的肩膀上。比如Go调用Tensorflow,也就是使用了CGO来实现的。很多语言都支持对C库的调用,一般称为FFI(Foreign Function Interface)。这里需要注意一下,一般的Python、Lua等语言调用C函数的时候,是通过函数签名找到函数的地址,然后直接调用对应的函数。而CGO会先生成中间文件,然后再一起编译调用。这里不具体展开了,总之只需要知道CGO比一般的FFI多了一层中间的步骤。在程序出错的时候,能理解调用栈的关系即可。参考 CGO 和 CGO 性能之谜

函数调用

package main

/*
CGO的标准写法:
1. 先用注释的方式写入,单行注释和多行注释都支持
  1.1 编译器环境变量
	1.1.1 CFLAGS: C编译选项
	1.1.2 CXXFLAGS: C++编译选项
	1.1.3 CPPFLAGS: C和C++共有的编译选项
	1.1.4 FFLAGS: Fortran编译选项
	1.1.4 LDFLAGS: 链接选项(不区分C和C++)
  1.2 C代码
2. import "C",相当于将所有的C函数放入虚拟的package C。之后通过`C.xxx`的方式来调用。需要紧跟注释之后。
*/

/*
#cgo LDFLAGS: -lm

#include <math.h>
double my_sqrt(double x) {
	return sqrt(x);
}
*/
import "C" // CGO的标准写法,相当于将C函数放入包`C`中
import "fmt"

func main() {
	a := 100
	fmt.Printf("sqrt(%v) = %v\n", a, C.my_sqrt(C.double(a))) // 调用上面自定义的函数
	fmt.Printf("sqrt(%v) = %v\n", a, C.sqrt(C.double(a)))    // 调用系统库函数(上述的m)
}

// Output:
//  sqrt(100) = 10
//  sqrt(100) = 10

通过上述的例子,我们可以看出,CGO可以调用C的库函数,也可以执行注释中的C代码的函数。一般注释中的C代码都是比较简短的。

接下来要介绍C和Go之间的数据是如何传递的。

数据类型

我们可以将数据类型分为以下几类: 1. 数值:两种语言均有自己的定义关键字 1. 整型: 8/16/32/64位有符号和无符号整数 2. 浮点型:32/64位浮点数 2. 字符串:C中的字符串是以\0结尾的字符数组,而Go中使用string 3. 结构体:包含0到多个字段的自定义结构。在两种语言中均可以用struct关键字来定义,二者十分相似。 4. 指针:本质上其实是一个32或64位的整数。用来指向内存的数据或函数的地址。 5. 数组:连续存放相同类型数据的一种结构。在C中一般理解为一块连续内存,在Go中,有Array和Slice两种形式,后续会介绍到。

接下来我们分别介绍不同的数据类型是如何在两种语言间传递和转换的。

基础数据类型

C stdint GO C/stdint -> GO Go -> C Go -> stdint
int8 signed char int8_t int8 int8 C.schar C.int8_t
int16 short int16_t int16 int16 C.short C.int16_t
int32 int int32_t int32 int32 C.int C.int32_t
int64 long long int64_t int64 int64 C.longlong C.int64_t
uint8 unsigned char uint8_t uint8 uint8 C.uchar C.uint8_t
uint16 unsigned short uint16_t uint16 uint16 C.ushort C.uint16_t
uint32 unsigned int uint32_t uint32 uint32 C.uint C.uint32_t
uint64 unsigned long long uint64_t uint64 uint64 C.ulonglong C.uint64_t
float32 float float32 float32 C.float
float64 double float64 float64 C.double
pointer DType * unsafe.Pointer unsafe.Pointer * C.DType
char char byte byte C.char
string char * string C.GoString C.CString

数值

按照字节的数目,整型一般占用1/2/4/8字节,另外包含有无符号,整型就有8种类型。对应的就是int8/int16/int32/int64四种有符号整型和uint8/uint16/uint32/uint64无符号整型。

在C中,通过char/short/int/long long等表示不同长度的有符号整型。如果include <stdint.h>,则可以用int8/int16/int32/int64等来表示。

浮点数则按照4/8字节,定义为float32/float64。

以下是一个例子,比较简单,就不具体介绍了。

package main

/*
#include <stdint.h>
*/
import "C"
import "fmt"

func main() {
	data := []float64{-1000, -50, -10, -5, -1, 0, 1, 5, 10, 50, 100, 10000, 3.14, -3.14, 1.0, -1.0, 0.5, -0.5}

	for _, d := range data {
		fmt.Printf("test number %v\n", d)

		// go
		int8_v := int8(d)
		int16_v := int16(d)
		int32_v := int32(d)
		int64_v := int64(d)
		uint8_v := uint8(d)
		uint16_v := uint16(d)
		uint32_v := uint32(d)
		uint64_v := uint64(d)
		float32_v := float32(d)
		float64_v := float64(d)

		// go -> c
		c_int8_v := C.schar(int8_v)
		c_int16_v := C.short(int16_v)
		c_int32_v := C.int(int32_v)
		c_int64_v := C.longlong(int64_v)
		c_uint8_v := C.uchar(uint8_v)
		c_uint16_v := C.ushort(uint16_v)
		c_uint32_v := C.uint(uint32_v)
		c_uint64_v := C.ulonglong(uint64_v)
		c_float32_v := C.float(float32_v)
		c_float64_v := C.double(float64_v)

		// go -> c stdint
		stdint_int8_v := C.int8_t(int8_v)
		stdint_int16_v := C.int16_t(int16_v)
		stdint_int32_v := C.int32_t(int32_v)
		stdint_int64_v := C.int64_t(int64_v)
		stdint_uint8_v := C.uint8_t(uint8_v)
		stdint_uint16_v := C.uint16_t(uint16_v)
		stdint_uint32_v := C.uint32_t(uint32_v)
		stdint_uint64_v := C.uint64_t(uint64_v)

		// c -> go
		go_int8_v := int8(c_int8_v)
		go_int16_v := int16(c_int16_v)
		go_int32_v := int32(c_int32_v)
		go_int64_v := int64(c_int64_v)
		go_uint8_v := uint8(c_uint8_v)
		go_uint16_v := uint16(c_uint16_v)
		go_uint32_v := uint32(c_uint32_v)
		go_uint64_v := uint64(c_uint64_v)
		go_float32_v := float32(c_float32_v)
		go_float64_v := float64(c_float64_v)

		fmt.Printf("int8    | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", int8_v, c_int8_v, stdint_int8_v, go_int8_v)
		fmt.Printf("int16   | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", int16_v, c_int16_v, stdint_int16_v, go_int16_v)
		fmt.Printf("int32   | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", int32_v, c_int32_v, stdint_int32_v, go_int32_v)
		fmt.Printf("int64   | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", int64_v, c_int64_v, stdint_int64_v, go_int64_v)
		fmt.Printf("uint8   | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", uint8_v, c_uint8_v, stdint_uint8_v, go_uint8_v)
		fmt.Printf("uint16  | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", uint16_v, c_uint16_v, stdint_uint16_v, go_uint16_v)
		fmt.Printf("uint32  | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", uint32_v, c_uint32_v, stdint_uint32_v, go_uint32_v)
		fmt.Printf("uint64  | raw %-10v | c %-10v | stdint %-10v | c->go %-10v\n", uint64_v, c_uint64_v, stdint_uint64_v, go_uint64_v)
		fmt.Printf("float32 | raw %-10v | c %-10v | c->go %-10v\n", float32_v, c_float32_v, go_float32_v)
		fmt.Printf("float64 | raw %-10v | c %-10v | c->go %-10v\n", float64_v, c_float64_v, go_float64_v)
	}
}

字符串

在C中,char表示单个字符(不考虑宽字符),而以\0结尾的字符数组就表示字符串。而Go中,单个字节用byte来表示,字符串用string来表示。因此Go中的byte/[]byte/string是三个不同的概念(string其实是包含lendata的一个结构体)。

CGO中,使用C.CString可以将Go string转换为C string,这里的C string是深拷贝的,因此需要使用完之后调用free来释放内存。使用C.GoString可以将C string转换为Go String,这里也是深拷贝的,之后的内存由Go自己管理。

深拷贝通常意味着需要内存的申请和复制,在极度在意性能的场景,我们需要用一些技巧来避免它。这部分后续会介绍到。

package main

/*
#include <stdio.h>
#include <stdlib.h>
void print_string(const char *s) {
	printf("%s\n", s);
}
*/
import "C"
import (
	"fmt"
	"unsafe"
)

func main() {
	str := "Hello World!"
	// byte
	{
		b := str[0]
		c_char := C.char(b)
		go_byte := byte(c_char)
		fmt.Printf("%v %v %v\n", b, c_char, go_byte)
	}

	// string
	{
		c_str := C.CString(str)     // Go -> C
		go_str := C.GoString(c_str) // C -> Go

		fmt.Printf("%v %v %v\n", str, c_str, go_str)
		C.print_string(c_str)

		C.free(unsafe.Pointer(c_str))
	}
}

指针

在C和Go中都有指针的概念,使用上也是相同的。

在两种指针做转换的时候,一般需要先转到中间的状态 unsafe.Pointer,以绕开编译检查。

下面的例子是分别用C和Go编写add函数,完成两个int的求和。由于C的int和和Go的int32类型的内存布局一致,都是4字节(Go的int是随机器位数变化的),因此C的add可以将结果直接写入Go的变量,反之亦可。

package main

/*
void add(int a, int b, int *c) {
        *c = a + b;
}
*/
import "C"
import (
	"fmt"
	"unsafe"
)

func add(a, b int32, c *int32) {
	*c = a + b
}

func main() {
	var a, b int32
	var go_sum int32
	var c_sum C.int

	// C add
	{
		a = 100
		b = 200
		C.add(C.int(a), C.int(b), &c_sum)
		fmt.Printf("%v + %v = %v\n", a, b, c_sum)

		C.add(C.int(a), C.int(b), (*C.int)(unsafe.Pointer(&go_sum)))
		fmt.Printf("%v + %v = %v\n", a, b, go_sum)
	}

	// Go add
	{
		a = 300
		b = 400
		add(a, b, &go_sum)
		fmt.Printf("%v + %v = %v\n", a, b, go_sum)

		add(a, b, (*int32)(unsafe.Pointer(&c_sum)))
		fmt.Printf("%v + %v = %v\n", a, b, c_sum)
	}
}

下面是两种调用下的数据转换的流程:

数据 地址 unsafe.Pointer 强转
C Add 变量 go_sum &go_sum unsafe.Pointer(&go_sum) (*C.int)(unsafe.Pointer(&go_sum))
C Add 类型 int *int unsafe.Pointer *C.int
Go Add 变量 c_sum &c_sum unsafe.Pointer(&c_sum) (*int)(unsafe.Pointer(&c_sum))
Go Add 类型 C.int *C.int unsafe.Pointer *int

结构体

C和Go中的结构体基本一致。CGO会将C中声明的结构体加上struct_前缀,用于区分。在使用上和Go struct的基本相同。

package main

/*
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
typedef struct CStruct {
	int int_val;
	float float_val;
	char *str_val;
} CStruct;

CStruct *new_struct() {
	CStruct *c = (CStruct *)malloc(sizeof(CStruct));
	memset(c, 0, sizeof(CStruct));
	return c;
}

void print_struct(CStruct* c) {
	printf("int: %d float: %f, str: %s\n", c->int_val, c->float_val, c->str_val);
}

void free_struct(CStruct* c) {
	if (!c) return;
	if (c->str_val) {
		free(c->str_val);
	}
	free(c);
}
*/
import "C"
import "unsafe"

func main() {
	{
		c := C.new_struct()
		C.print_struct(c)
		C.free_struct(c)
	}

	{
		c := C.new_struct()
		c.int_val = C.int(100)
		c.float_val = C.float(0.618)
		c.str_val = C.CString("hello")
		C.print_struct(c)
		C.free_struct(c)
	}

	{
		c := C.struct_CStruct{} // or c := &C.struct_CStruct{}
		c.int_val = C.int(100)
		c.float_val = C.float(0.618)
		c.str_val = C.CString("world")
		C.print_struct(&c)
		C.free(unsafe.Pointer(c.str_val))
		// C.free_struct(&c) // crash
	}
}

// Output:
// int: 0 float: 0.000000, str: (null)
// int: 100 float: 0.618000, str: hello
// int: 100 float: 0.618000, str: world

这里有两个地方需要注意:

  1. Go中结构体取值和指针取值都是使用 . 操作符。
  2. 第三个例子,CStruct对象是在Go中创建的,因此不能调用C的free_struct接口。

数组

我们知道,在C语言中,数组和指针大多数情况下是等价的,一般接口返回的也是指针类型。我们可以用 arr[i] 或者 *(arr + i)的方式访问数据中的某一项。

在Go中,有Array和Slice两种类型。其中Array的长度是固定的,而Slice可变。一般使用上更多的会使用Slice。

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
int* create_array(int size) {
	if (size <= 0) return NULL;
	int *arr = (int *)malloc(size * sizeof(int));
	memset(arr, 0, size * sizeof(int));
	return arr;
}
void fill_array(int *arr, int size) {
	for (int idx = 0; idx < size; ++ idx) {
		arr[idx] = idx;
	}
}
void print_array(int *arr, int size) {
	for (int idx = 0; idx < size; ++ idx) {
		printf("%d ", arr[idx]);
	}
	printf("\n");
}
void del_array(int *arr) {
	if (arr) {
		free(arr);
	}
}
*/
import "C"
import (
	"fmt"
	"unsafe"
)

func main() {
	{
		fmt.Println("Read C array")
		size := 5
		cSize := C.int(5)
		arr := C.create_array(cSize) // [0 0 0 0 0]
		C.print_array(arr, cSize)
		C.fill_array(arr, cSize) // [0 1 2 3 4]
		C.print_array(arr, cSize)

		{
			fmt.Println("Pointer + Offset")
			p := uintptr(unsafe.Pointer(arr))
			for idx := 0; idx < size; idx++ {
				fmt.Printf("idx %v val %v\n", idx, *(*C.int)(unsafe.Pointer(p)))
				p += uintptr(C.sizeof_int)
			}
		}

		{
			fmt.Println("Map C Array to Go Slice")
			int_slice := (*[1 << 31]C.int)(unsafe.Pointer(arr))[0:5:5]
			for idx, v := range int_slice {
				fmt.Printf("idx %v val %v\n", idx, v)
			}
		}
		C.del_array(arr)
	}

	{
		fmt.Println("Passthrough Go Array/Slice")
		int_arr := [5]C.int{5, 4, 3, 2, 1}
		C.print_array(&int_arr[0], 5)

		int_slice := int_arr[0:5:5]
		C.print_array(&int_slice[0], 5)
	}
}

// Output
// Read C array
// 0 0 0 0 0
// 0 1 2 3 4
// Pointer + Offset
// idx 0 val 0
// idx 1 val 1
// idx 2 val 2
// idx 3 val 3
// idx 4 val 4
// Map C Array to Go Slice
// idx 0 val 0
// idx 1 val 1
// idx 2 val 2
// idx 3 val 3
// idx 4 val 4
// Passthrough Go Array/Slice
// 5 4 3 2 1
// 5 4 3 2 1

例子中的第一部分是,Go读取C的数组。第二部分是C读取Go的数组。

Go读取C Array有两个方式:1. 裸操作指针, 2. 映射成Go Slice再操作。这里推荐第二种。

Pointer + Offset

C的数组本质上只是一个指针。我们只需要计算好每个元素的地址,就可以读写对应的元素。

Go有两个指针类型unsafe.Pointeruintptr。二者可以互相转换,但是只有uintptr才支持指针的计算。因此,我们先将数组的首地址转成uintptr,再加上偏移量,最后强转回C指针。

  1. 首地址: p := uintptr(unsafe.Pointer(arr))
  2. 强转回C指针并取值:*(*C.int)(unsafe.Pointer(p))
  3. 移动到下一个元素的位置:p += uintptr(C.sizeof_int),其中C.sizeof_XXX 可以获取 XXX 的sizeof的大小。
  4. 重复 2-3,直到数组结束

Map C Array to Go Slice

上述方法其实用起来还是有点麻烦的,需要自己控制指针,也容易出错。

实际上,Go Array本质上也是连续的一块内存,其内存视图和C中完全一致(Go Slice不一样,Slice是个包含Len, Cap, Data的结构体,其中Data的部分和C一致)。我们只需要欺骗Go,让他认为这是一个Go Array即可。

  1. 转成中间指针:unsafe.Pointer(arr)
  2. 强转成Go Array的指针:(*[1 << 31]C.int)(unsafe.Pointer(arr))。这里的*[1 << 31]C.int的含义是指向 [1 << 31]C.int的指针,其中1 << 31 只是用位运算表达一个很大的数而已,避免下一步的切面操作越界。
  3. 对数组做切片得到Go Slice:(*[1 << 31]C.int)(unsafe.Pointer(arr))[0:5:5],切面的参数[0:5:5]表示取[0, 5)对应的数据,且capacity为5。

这里的内存是C程序来维护的,Go Array/Slice只是C Array的View,并不维护C的内存。因此需要开发者自己维护好C数据的生命周期。如果不在乎性能的话,最简单的方法就是Deep Copy整个Slice。或者将析构的函数返回给接口的调用方,让其自行处理。

有聪明的小伙伴可能会想到通过 runtime.SetFinalizer 在Go Slice对象析构时,调用C的析构函数。这是有问题的,看下面的例子:

package main

import (
	"fmt"
	"runtime"
)

func main() {
	var b []int
	{
		a := []int{1, 2, 3}
		runtime.SetFinalizer(&a, func(interface{}) {
			fmt.Println("free a")
		})
		b = a
	}

	fmt.Println("111")
	runtime.GC()
	fmt.Printf("%v\n", b)
	fmt.Println("222")
}

// Output
// 111
// free a
// [1 2 3]
// 222

在第一次调用 runtime.GC() 的时候,a 就已经被释放了,此时如果 SetFinalizer 中调用了C的析构函数,会导致 b 访问异常的数据。原因是在Go中,Map和Slice都只是很小的Header结构体,b = a 语句其实是Copy了这个Header(并不是引用),后续析构的是 a 这个Header。

Passthrough Go Slice to C

将Go的Array/Slice透传给C就十分简单了。由于Go Array/Slice底层的数据都是连续的,获取第0个元素的地址即可。

  • easy: &arr[0]

Union / Enum

CGO通过 C.union_XXX 来使用union 类型,通过 C.enum_XXX 使用 enum 的类型,这点和struct相同。

对于定义好的enum类型,可以直接 C.XXX的方式直接访问枚举值。

由于Go中没有C Union的结构,Go会将C Union的数据转换为对应大小的字节数组。我们在使用的时候需要自行转换成对应的类型来读写。

package main

/*
#include <stdio.h>
enum Type {CHAR, INT, DOUBLE};
typedef union Value {
	char char_v;
	int int_v;
	double double_v;
} Value;

typedef struct Number {
	enum Type type;
	Value v;
} Number;

void print_number(struct Number *n) {
	switch (n->type) {
	case CHAR:
		printf("char %c\n", n->v.char_v);
		break;
	case INT:
		printf("int %d\n", n->v.int_v);
		break;
	case DOUBLE:
		printf("double %f\n", n->v.double_v);
		break;
	}
}
*/
import "C"
import (
	"fmt"
	"reflect"
	"unsafe"
)

func main() {
	n := &C.struct_Number{}
	// n.v.int_v = 100 // unsupported

	fmt.Printf("%v %v %v %v\n", reflect.TypeOf(n), reflect.TypeOf(n.v), reflect.TypeOf(n.v[0]), len(n.v))

	n._type = C.CHAR
	*(*C.char)(unsafe.Pointer(&(n.v))) = 'Z'
	C.print_number(n)

	n._type = C.INT
	*(*C.int)(unsafe.Pointer(&(n.v))) = 100
	C.print_number(n)

	n._type = C.DOUBLE
	*(*C.double)(unsafe.Pointer(&(n.v))) = 3.14
	C.print_number(n)
}

// Output
//  *main._Ctype_struct_Number main._Ctype_Value uint8 8
//  char Z
//  int 100
//  double 3.140000

这里需要注意:

  1. 名字为type的成员变量,需要用_type来访问。
  2. union Value类型,根据 reflect.TypeOf 的结果可以看出,它本质上是[8]uint8。这里的8C.sizeof_double的大小。
  3. [8]uint8 取地址之后,得到内存的首地址,再强转为我们期望的类型的指针,才可以使用。

Zero Copy

前面的数组的章节。我们访问C Array的时候,其实并没有做拷贝,而是直接访问的内存。但是对于string的操作,调用C.GoStringC.CString都做了深拷贝。

在Go中,string[]byte的转换,其实存在拷贝。如果我们可以让这两者之间的转换不存在拷贝的话。那么Go []byte和C char *也可以无拷贝的转换了。

这里需要注意的是,这里的[]byte对象的最后一个 byte 并没有规定必须是\0。因此如果直接将指针传给C程序的话,建议再传一个长度,否则很容易出现越界的错误(下面的例子CGO的部分可能会crash)。

例子:

package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

/*
#include <stdio.h>
void print(char *s) {
	printf("%s\n", s);
}
*/
import "C"

func String2BytesUnsafe(s string) []byte {
	if len(s) == 0 {
		return nil
	}
	return (*[0x7fff0000]byte)(unsafe.Pointer(
		(*reflect.StringHeader)(unsafe.Pointer(&s)).Data),
	)[:len(s):len(s)]
}

func Bytes2StringUnsafe(bytes []byte) (out string) {
	if len(bytes) == 0 {
		return
	}
	sliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&bytes))
	header := (*reflect.StringHeader)(unsafe.Pointer(&out))
	header.Data = sliceHeader.Data
	header.Len = sliceHeader.Len
	return
}

func testString2BytesUnsafe(s string) {
	bytes := String2BytesUnsafe(s)
	fmt.Printf("String2Bytes Before: %v\n", s)
	bytes[0] = 'Q'
	fmt.Printf("String2Bytes After: %v\n", s)
}

func testBytes2StringUnsafe(bytes []byte) {
	s := Bytes2StringUnsafe(bytes)
	fmt.Printf("Bytes2String Before: %v\n", s)
	bytes[0] = 'Q'
	fmt.Printf("Bytes2String After: %v\n", s)
}

func main() {
	raw := []byte("Hello World!")
	testString2BytesUnsafe(string(raw))
	testBytes2StringUnsafe(raw)
	s1 := string(raw)
	C.print((*C.char)(unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s1)).Data))) // may crash
}

// Output
// String2Bytes Before: Hello World!
// String2Bytes After: Qello World!
// Bytes2String Before: Hello World!
// Bytes2String After: Qello World!
// Qello World!

String -> Bytes

Go中的string其实是一个小对象。

type StringHeader struct {
	Data uintptr
	Len  int
}

string强转为reflect.StringHeader的结构,即可拿到Data的数据,再构造为[]byte即可(当然如果在CGO的话,直接用这个Data数据也行)。

Bytes -> String

和前面类似,构造reflect.StringHeader结构,再强转为string类型。