Last active
May 16, 2021 20:11
-
-
Save zephyrtronium/c38bc29c87cba6b9583bf145ef284b14 to your computer and use it in GitHub Desktop.
Windows OpenCL demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package main | |
| import ( | |
| "fmt" | |
| "log" | |
| "runtime" | |
| "syscall" | |
| "unsafe" | |
| "golang.org/x/sys/windows" | |
| ) | |
| type ocl struct { | |
| dll windows.Handle | |
| // procs | |
| clGetPlatformIDs uintptr | |
| clGetPlatformInfo uintptr | |
| } | |
| func loadcl() (*ocl, error) { | |
| hnd, err := windows.LoadLibraryEx("OpenCL", 0, windows.LOAD_LIBRARY_SEARCH_DEFAULT_DIRS) | |
| if err != nil { | |
| return nil, err | |
| } | |
| cl := ocl{dll: hnd} | |
| fin := func(d *ocl) { | |
| if err := d.release(); err != nil { | |
| log.Println("couldn't free opencl library during finalizer:", err) | |
| } | |
| d.dll = 0 | |
| } | |
| runtime.SetFinalizer(&cl, fin) | |
| if cl.clGetPlatformIDs, err = windows.GetProcAddress(cl.dll, "clGetPlatformIDs"); err != nil { | |
| cl.release() | |
| return nil, err | |
| } | |
| if cl.clGetPlatformInfo, err = windows.GetProcAddress(cl.dll, "clGetPlatformInfo"); err != nil { | |
| cl.release() | |
| return nil, err | |
| } | |
| return &cl, nil | |
| } | |
| func (cl *ocl) release() error { | |
| err := windows.FreeLibrary(cl.dll) | |
| *cl = ocl{} // zero all procs | |
| return err | |
| } | |
| func (cl *ocl) platforms() ([]uintptr, error) { | |
| var n uint32 | |
| r, _, errno := syscall.Syscall(cl.clGetPlatformIDs, 3, 0, 0, uintptr(unsafe.Pointer(&n))) | |
| if errno != 0 { | |
| log.Printf("platformvers: first call with r=%[2]d/%[3]d, n[4]=%d gave errno %[1]d: %[1]v", errno, int32(r), int32(r>>32), n) | |
| if r != 0 || n > 4 || n == 0 { | |
| return nil, errno | |
| } | |
| } | |
| pls := make([]uintptr, n) | |
| r, _, errno = syscall.Syscall(cl.clGetPlatformIDs, 3, uintptr(n), uintptr(unsafe.Pointer(&pls[0])), 0) | |
| if errno != 0 { | |
| log.Printf("platformvers: first call with r=%[2]d/%[3]d gave errno %d[1]: %[1]v", errno, int32(r), int32(r>>32)) | |
| return nil, errno | |
| } | |
| log.Printf("platformvers: second clGetPlatformIDs: ret %d/%d", int32(r), int32(r>>32)) | |
| return pls, nil | |
| } | |
| func (cl *ocl) platformstr(platform, info uintptr) (string, error) { | |
| var sz uintptr | |
| r, _, errno := syscall.Syscall6(cl.clGetPlatformInfo, 5, platform, info, 0, 0, uintptr(unsafe.Pointer(&sz)), 0) | |
| if errno != 0 { | |
| log.Printf("platformstr(%[2]x): first call with r=%[3]d gave errno %[1]d: %[1]v", errno, info, int32(r)) | |
| if r != 0 || sz > 1<<20 || sz == 0 { | |
| return "", errno | |
| } | |
| } | |
| if r != 0 { | |
| log.Printf("platformstr(%x): ret %d, sz %d", info, int32(r), sz) | |
| } | |
| if sz > 1<<20 { | |
| log.Fatal("I REFUSE TO ALLOCATE A STRING THAT BIG HOLY H*CK") | |
| } | |
| if sz == 0 { | |
| return "", nil | |
| } | |
| chr := make([]byte, sz) | |
| r, _, errno = syscall.Syscall6(cl.clGetPlatformInfo, 5, platform, info, sz, uintptr(unsafe.Pointer(&chr[0])), 0, 0) | |
| if errno != 0 { | |
| return "", errno | |
| } | |
| if r != 0 { | |
| log.Printf("platformstr(%x): ret %d, sz %d", info, int32(r), sz) | |
| } | |
| return string(chr[:sz-1]), nil | |
| } | |
| func main() { | |
| log.SetFlags(log.Lshortfile) | |
| cl, err := loadcl() | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| pls, err := cl.platforms() | |
| if err != nil { | |
| log.Fatal(err) | |
| } | |
| log.Println("platform ids:", pls) | |
| for i, p := range pls { | |
| vendor, err := cl.platformstr(p, CL_PLATFORM_VENDOR) | |
| if err != nil { | |
| log.Printf("platform %d: couldn't get vendor: %v", i, err) | |
| continue | |
| } | |
| name, err := cl.platformstr(p, CL_PLATFORM_NAME) | |
| if err != nil { | |
| log.Printf("platform %d: couldn't get name: %v", i, err) | |
| continue | |
| } | |
| ver, err := cl.platformstr(p, CL_PLATFORM_VERSION) | |
| if err != nil { | |
| log.Printf("platform %d: couldn't get version: %v", i, err) | |
| continue | |
| } | |
| // ext, err := cl.platformstr(p, CL_PLATFORM_EXTENSIONS) | |
| // if err != nil { | |
| // log.Printf("platform %d: couldn't get extensions: %v", i, err) | |
| // continue | |
| // } | |
| fmt.Println("\nplatform", i) | |
| fmt.Println("\tvendor:", vendor) | |
| fmt.Println("\tname:", name) | |
| fmt.Println("\tversion:", ver) | |
| // fmt.Println("\textensions:", ext) | |
| } | |
| if err := cl.release(); err != nil { | |
| log.Fatal(err) | |
| } | |
| } | |
// cl_platform_info parameter names accepted by clGetPlatformInfo. The values
// are fixed by the Khronos OpenCL headers (CL/cl.h), so they are the same
// for every vendor (they were originally copied from the cl.h shipped with
// CUDA 9.0). Each one selects a NUL-terminated string value. The C spelling
// is kept on purpose so the names match the headers and the spec.
const (
	CL_PLATFORM_PROFILE    = 0x0900
	CL_PLATFORM_VERSION    = 0x0901
	CL_PLATFORM_NAME       = 0x0902
	CL_PLATFORM_VENDOR     = 0x0903
	CL_PLATFORM_EXTENSIONS = 0x0904
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module demo

go 1.16

// golang.org/x/sys/windows is imported directly by main.go, so this is a
// direct requirement (no "// indirect" marker).
require golang.org/x/sys v0.0.0-20210514084401-e8d321eab015
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E= | |
| golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment